From 79ea94483ef8aa2bd30f088acd0f474ee3c93d16 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Thu, 31 Jul 2025 19:43:04 +0900 Subject: [PATCH] refine some orm types (#22885) --- api/models/account.py | 94 ++++---- api/models/api_based_extension.py | 13 +- api/models/dataset.py | 220 +++++++++--------- api/models/model.py | 196 ++++++++-------- api/models/provider.py | 86 +++---- api/models/source.py | 26 ++- api/models/task.py | 19 +- api/models/tools.py | 100 ++++---- api/models/web.py | 12 +- api/models/workflow.py | 62 ++--- api/services/dataset_service.py | 5 + .../batch_create_segment_to_index_task.py | 1 + 12 files changed, 424 insertions(+), 410 deletions(-) diff --git a/api/models/account.py b/api/models/account.py index d63c5d7fb..343705589 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Optional, cast from flask_login import UserMixin # type: ignore -from sqlalchemy import func, select +from sqlalchemy import DateTime, String, func, select from sqlalchemy.orm import Mapped, mapped_column, reconstructor from models.base import Base @@ -86,23 +86,21 @@ class Account(UserMixin, Base): __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - name: Mapped[str] = mapped_column(db.String(255)) - email: Mapped[str] = mapped_column(db.String(255)) - password: Mapped[Optional[str]] = mapped_column(db.String(255)) - password_salt: Mapped[Optional[str]] = mapped_column(db.String(255)) - avatar: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - interface_language: Mapped[Optional[str]] = mapped_column(db.String(255)) - interface_theme: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - timezone: Mapped[Optional[str]] = mapped_column(db.String(255)) - last_login_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - last_login_ip: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - last_active_at: Mapped[datetime] = mapped_column( - db.DateTime, server_default=func.current_timestamp(), nullable=False - ) - status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'active'::character varying")) - initialized_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) + name: Mapped[str] = mapped_column(String(255)) + email: Mapped[str] = mapped_column(String(255)) + password: Mapped[Optional[str]] = mapped_column(String(255)) + password_salt: Mapped[Optional[str]] = mapped_column(String(255)) + avatar: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + interface_language: Mapped[Optional[str]] = mapped_column(String(255)) + interface_theme: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + timezone: Mapped[Optional[str]] = mapped_column(String(255)) + last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + last_login_ip: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) + status: Mapped[str] = mapped_column(String(16), server_default=db.text("'active'::character varying")) + initialized_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) @reconstructor def init_on_load(self): @@ -200,13 +198,13 @@ class Tenant(Base): __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - name: Mapped[str] = mapped_column(db.String(255)) + name: Mapped[str] = mapped_column(String(255)) encrypt_public_key = db.Column(db.Text) - plan: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'basic'::character varying")) - status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying")) + plan: Mapped[str] = mapped_column(String(255), server_default=db.text("'basic'::character varying")) + status: Mapped[str] = mapped_column(String(255), server_default=db.text("'normal'::character varying")) custom_config: Mapped[Optional[str]] = mapped_column(db.Text) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) def get_accounts(self) -> list[Account]: return ( @@ -237,10 +235,10 @@ class TenantAccountJoin(Base): tenant_id: Mapped[str] = mapped_column(StringUUID) account_id: Mapped[str] = mapped_column(StringUUID) current: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) - role: Mapped[str] = mapped_column(db.String(16), server_default="normal") + role: Mapped[str] = mapped_column(String(16), server_default="normal") invited_by: Mapped[Optional[str]] = mapped_column(StringUUID) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) class AccountIntegrate(Base): @@ -253,11 +251,11 @@ class AccountIntegrate(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) account_id: Mapped[str] = mapped_column(StringUUID) - provider: Mapped[str] = mapped_column(db.String(16)) - open_id: Mapped[str] = mapped_column(db.String(255)) - encrypted_token: Mapped[str] = mapped_column(db.String(255)) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + provider: Mapped[str] = mapped_column(String(16)) + open_id: Mapped[str] = mapped_column(String(255)) + encrypted_token: Mapped[str] = mapped_column(String(255)) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) class InvitationCode(Base): @@ -269,14 +267,14 @@ class InvitationCode(Base): ) id: Mapped[int] = mapped_column(db.Integer) - batch: Mapped[str] = mapped_column(db.String(255)) - code: Mapped[str] = mapped_column(db.String(32)) - status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'unused'::character varying")) - used_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + batch: Mapped[str] = mapped_column(String(255)) + code: Mapped[str] = mapped_column(String(32)) + status: Mapped[str] = mapped_column(String(16), server_default=db.text("'unused'::character varying")) + used_at: Mapped[Optional[datetime]] = mapped_column(DateTime) used_by_tenant_id: Mapped[Optional[str]] = mapped_column(StringUUID) used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) - deprecated_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=db.text("CURRENT_TIMESTAMP(0)")) + deprecated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TenantPluginPermission(Base): @@ -298,10 +296,8 @@ class TenantPluginPermission(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - install_permission: Mapped[InstallPermission] = mapped_column( - db.String(16), nullable=False, server_default="everyone" - ) - debug_permission: Mapped[DebugPermission] = mapped_column(db.String(16), nullable=False, server_default="noone") + install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone") + debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone") class TenantPluginAutoUpgradeStrategy(Base): @@ -323,14 +319,10 @@ class TenantPluginAutoUpgradeStrategy(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - strategy_setting: Mapped[StrategySetting] = mapped_column(db.String(16), nullable=False, server_default="fix_only") + strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only") upgrade_time_of_day: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) # seconds of the day - upgrade_mode: Mapped[UpgradeMode] = mapped_column(db.String(16), nullable=False, server_default="exclude") - exclude_plugins: Mapped[list[str]] = mapped_column( - db.ARRAY(db.String(255)), nullable=False - ) # plugin_id (author/name) - include_plugins: Mapped[list[str]] = mapped_column( - db.ARRAY(db.String(255)), nullable=False - ) # plugin_id (author/name) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude") + exclude_plugins: Mapped[list[str]] = mapped_column(db.ARRAY(String(255)), nullable=False) # plugin_id (author/name) + include_plugins: Mapped[list[str]] = mapped_column(db.ARRAY(String(255)), nullable=False) # plugin_id (author/name) + created_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 3cef5a0fb..ac9eda682 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,7 +1,8 @@ import enum +from datetime import datetime -from sqlalchemy import func -from sqlalchemy.orm import mapped_column +from sqlalchemy import DateTime, String, Text, func +from sqlalchemy.orm import Mapped, mapped_column from .base import Base from .engine import db @@ -24,7 +25,7 @@ class APIBasedExtension(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - name = mapped_column(db.String(255), nullable=False) - api_endpoint = mapped_column(db.String(255), nullable=False) - api_key = mapped_column(db.Text, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + name: Mapped[str] = mapped_column(String(255), nullable=False) + api_endpoint: Mapped[str] = mapped_column(String(255), nullable=False) + api_key = mapped_column(Text, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/dataset.py b/api/models/dataset.py index 01372f8bf..4d41d0c8b 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -12,7 +12,7 @@ from datetime import datetime from json import JSONDecodeError from typing import Any, Optional, cast -from sqlalchemy import func, select +from sqlalchemy import DateTime, String, func, select from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column @@ -48,22 +48,22 @@ class Dataset(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) - name: Mapped[str] = mapped_column(db.String(255)) + name: Mapped[str] = mapped_column(String(255)) description = mapped_column(db.Text, nullable=True) - provider: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'vendor'::character varying")) - permission: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'only_me'::character varying")) - data_source_type = mapped_column(db.String(255)) - indexing_technique: Mapped[Optional[str]] = mapped_column(db.String(255)) + provider: Mapped[str] = mapped_column(String(255), server_default=db.text("'vendor'::character varying")) + permission: Mapped[str] = mapped_column(String(255), server_default=db.text("'only_me'::character varying")) + data_source_type = mapped_column(String(255)) + indexing_technique: Mapped[Optional[str]] = mapped_column(String(255)) index_struct = mapped_column(db.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - embedding_model = db.Column(db.String(255), nullable=True) # TODO: mapped_column - embedding_model_provider = db.Column(db.String(255), nullable=True) # TODO: mapped_column + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + embedding_model = db.Column(String(255), nullable=True) # TODO: mapped_column + embedding_model_provider = db.Column(String(255), nullable=True) # TODO: mapped_column collection_binding_id = mapped_column(StringUUID, nullable=True) retrieval_model = mapped_column(JSONB, nullable=True) - built_in_field_enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + built_in_field_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) @property def dataset_keyword_table(self): @@ -268,10 +268,10 @@ class DatasetProcessRule(Base): id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) dataset_id = mapped_column(StringUUID, nullable=False) - mode = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + mode = mapped_column(String(255), nullable=False, server_default=db.text("'automatic'::character varying")) rules = mapped_column(db.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) MODES = ["automatic", "custom", "hierarchical"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] @@ -313,61 +313,59 @@ class Document(Base): id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False) - data_source_type = mapped_column(db.String(255), nullable=False) + position: Mapped[int] = mapped_column(db.Integer, nullable=False) + data_source_type: Mapped[str] = mapped_column(String(255), nullable=False) data_source_info = mapped_column(db.Text, nullable=True) dataset_process_rule_id = mapped_column(StringUUID, nullable=True) - batch = mapped_column(db.String(255), nullable=False) - name = mapped_column(db.String(255), nullable=False) - created_from = mapped_column(db.String(255), nullable=False) + batch: Mapped[str] = mapped_column(String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + created_from: Mapped[str] = mapped_column(String(255), nullable=False) created_by = mapped_column(StringUUID, nullable=False) created_api_request_id = mapped_column(StringUUID, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) # start processing - processing_started_at = mapped_column(db.DateTime, nullable=True) + processing_started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # parsing file_id = mapped_column(db.Text, nullable=True) - word_count = mapped_column(db.Integer, nullable=True) - parsing_completed_at = mapped_column(db.DateTime, nullable=True) + word_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) # TODO: make this not nullable + parsing_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # cleaning - cleaning_completed_at = mapped_column(db.DateTime, nullable=True) + cleaning_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # split - splitting_completed_at = mapped_column(db.DateTime, nullable=True) + splitting_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # indexing - tokens = mapped_column(db.Integer, nullable=True) - indexing_latency = mapped_column(db.Float, nullable=True) - completed_at = mapped_column(db.DateTime, nullable=True) + tokens: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) + indexing_latency: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True) + completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # pause - is_paused = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + is_paused: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) paused_by = mapped_column(StringUUID, nullable=True) - paused_at = mapped_column(db.DateTime, nullable=True) + paused_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # error error = mapped_column(db.Text, nullable=True) - stopped_at = mapped_column(db.DateTime, nullable=True) + stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # basic fields - indexing_status = mapped_column( - db.String(255), nullable=False, server_default=db.text("'waiting'::character varying") - ) - enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - disabled_at = mapped_column(db.DateTime, nullable=True) + indexing_status = mapped_column(String(255), nullable=False, server_default=db.text("'waiting'::character varying")) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) - archived = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - archived_reason = mapped_column(db.String(255), nullable=True) + archived: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + archived_reason = mapped_column(String(255), nullable=True) archived_by = mapped_column(StringUUID, nullable=True) - archived_at = mapped_column(db.DateTime, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - doc_type = mapped_column(db.String(40), nullable=True) + archived_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + doc_type = mapped_column(String(40), nullable=True) doc_metadata = mapped_column(JSONB, nullable=True) - doc_form = mapped_column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) - doc_language = mapped_column(db.String(255), nullable=True) + doc_form = mapped_column(String(255), nullable=False, server_default=db.text("'text_model'::character varying")) + doc_language = mapped_column(String(255), nullable=True) DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] @@ -524,7 +522,7 @@ class Document(Base): "id": "built-in", "name": BuiltInField.upload_date, "type": "time", - "value": self.created_at.timestamp(), + "value": str(self.created_at.timestamp()), } ) built_in_fields.append( @@ -532,7 +530,7 @@ class Document(Base): "id": "built-in", "name": BuiltInField.last_update_date, "type": "time", - "value": self.updated_at.timestamp(), + "value": str(self.updated_at.timestamp()), } ) built_in_fields.append( @@ -667,23 +665,23 @@ class DocumentSegment(Base): # indexing fields keywords = mapped_column(db.JSON, nullable=True) - index_node_id = mapped_column(db.String(255), nullable=True) - index_node_hash = mapped_column(db.String(255), nullable=True) + index_node_id = mapped_column(String(255), nullable=True) + index_node_hash = mapped_column(String(255), nullable=True) # basic fields - hit_count = mapped_column(db.Integer, nullable=False, default=0) - enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - disabled_at = mapped_column(db.DateTime, nullable=True) + hit_count: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) - status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'waiting'::character varying")) + status: Mapped[str] = mapped_column(String(255), server_default=db.text("'waiting'::character varying")) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - indexing_at = mapped_column(db.DateTime, nullable=True) - completed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) error = mapped_column(db.Text, nullable=True) - stopped_at = mapped_column(db.DateTime, nullable=True) + stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) @property def dataset(self): @@ -808,19 +806,23 @@ class ChildChunk(Base): dataset_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) segment_id = mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False) + position: Mapped[int] = mapped_column(db.Integer, nullable=False) content = mapped_column(db.Text, nullable=False) - word_count = mapped_column(db.Integer, nullable=False) + word_count: Mapped[int] = mapped_column(db.Integer, nullable=False) # indexing fields - index_node_id = mapped_column(db.String(255), nullable=True) - index_node_hash = mapped_column(db.String(255), nullable=True) - type = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + index_node_id = mapped_column(String(255), nullable=True) + index_node_hash = mapped_column(String(255), nullable=True) + type = mapped_column(String(255), nullable=False, server_default=db.text("'automatic'::character varying")) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - indexing_at = mapped_column(db.DateTime, nullable=True) - completed_at = mapped_column(db.DateTime, nullable=True) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) error = mapped_column(db.Text, nullable=True) @property @@ -846,7 +848,7 @@ class AppDatasetJoin(Base): id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp()) @property def app(self): @@ -863,11 +865,11 @@ class DatasetQuery(Base): id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) dataset_id = mapped_column(StringUUID, nullable=False) content = mapped_column(db.Text, nullable=False) - source = mapped_column(db.String(255), nullable=False) + source: Mapped[str] = mapped_column(String(255), nullable=False) source_app_id = mapped_column(StringUUID, nullable=True) - created_by_role = mapped_column(db.String, nullable=False) + created_by_role = mapped_column(String, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp()) class DatasetKeywordTable(Base): @@ -881,7 +883,7 @@ class DatasetKeywordTable(Base): dataset_id = mapped_column(StringUUID, nullable=False, unique=True) keyword_table = mapped_column(db.Text, nullable=False) data_source_type = mapped_column( - db.String(255), nullable=False, server_default=db.text("'database'::character varying") + String(255), nullable=False, server_default=db.text("'database'::character varying") ) @property @@ -925,12 +927,12 @@ class Embedding(Base): id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) model_name = mapped_column( - db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying") + String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying") ) - hash = mapped_column(db.String(64), nullable=False) + hash = mapped_column(String(64), nullable=False) embedding = mapped_column(db.LargeBinary, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - provider_name = mapped_column(db.String(255), nullable=False, server_default=db.text("''::character varying")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name = mapped_column(String(255), nullable=False, server_default=db.text("''::character varying")) def set_embedding(self, embedding_data: list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) @@ -947,11 +949,11 @@ class DatasetCollectionBinding(Base): ) id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - provider_name = mapped_column(db.String(255), nullable=False) - model_name = mapped_column(db.String(255), nullable=False) - type = mapped_column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) - collection_name = mapped_column(db.String(64), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + type = mapped_column(String(40), server_default=db.text("'dataset'::character varying"), nullable=False) + collection_name = mapped_column(String(64), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class TidbAuthBinding(Base): @@ -965,13 +967,13 @@ class TidbAuthBinding(Base): ) id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) - cluster_id = mapped_column(db.String(255), nullable=False) - cluster_name = mapped_column(db.String(255), nullable=False) - active = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("CREATING")) - account = mapped_column(db.String(255), nullable=False) - password = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + cluster_id: Mapped[str] = mapped_column(String(255), nullable=False) + cluster_name: Mapped[str] = mapped_column(String(255), nullable=False) + active: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + status = mapped_column(String(255), nullable=False, server_default=db.text("CREATING")) + account: Mapped[str] = mapped_column(String(255), nullable=False) + password: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class Whitelist(Base): @@ -982,8 +984,8 @@ class Whitelist(Base): ) id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) - category = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + category: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class DatasetPermission(Base): @@ -999,8 +1001,8 @@ class DatasetPermission(Base): dataset_id = mapped_column(StringUUID, nullable=False) account_id = mapped_column(StringUUID, nullable=False) tenant_id = mapped_column(StringUUID, nullable=False) - has_permission = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + has_permission: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class ExternalKnowledgeApis(Base): @@ -1012,14 +1014,14 @@ class ExternalKnowledgeApis(Base): ) id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - name = mapped_column(db.String(255), nullable=False) - description = mapped_column(db.String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[str] = mapped_column(String(255), nullable=False) tenant_id = mapped_column(StringUUID, nullable=False) settings = mapped_column(db.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) def to_dict(self): return { @@ -1072,9 +1074,9 @@ class ExternalKnowledgeBindings(Base): dataset_id = mapped_column(StringUUID, nullable=False) external_knowledge_id = mapped_column(db.Text, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class DatasetAutoDisableLog(Base): @@ -1090,8 +1092,10 @@ class DatasetAutoDisableLog(Base): tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) - notified = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + notified: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) class RateLimitLog(Base): @@ -1104,9 +1108,11 @@ class RateLimitLog(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - subscription_plan = mapped_column(db.String(255), nullable=False) - operation = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + subscription_plan: Mapped[str] = mapped_column(String(255), nullable=False) + operation: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) class DatasetMetadata(Base): @@ -1120,10 +1126,14 @@ class DatasetMetadata(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - type = mapped_column(db.String(255), nullable=False) - name = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + type: Mapped[str] = mapped_column(String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) created_by = mapped_column(StringUUID, nullable=False) updated_by = mapped_column(StringUUID, nullable=True) @@ -1143,5 +1153,5 @@ class DatasetMetadataBinding(Base): dataset_id = mapped_column(StringUUID, nullable=False) metadata_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) created_by = mapped_column(StringUUID, nullable=False) diff --git a/api/models/model.py b/api/models/model.py index 9f6d51b31..fba0d692e 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: import sqlalchemy as sa from flask import request from flask_login import UserMixin -from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text +from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, func, text from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config @@ -37,7 +37,7 @@ class DifySetup(Base): __tablename__ = "dify_setups" __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) - version = mapped_column(db.String(255), nullable=False) + version: Mapped[str] = mapped_column(String(255), nullable=False) setup_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -73,15 +73,15 @@ class App(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) - name: Mapped[str] = mapped_column(db.String(255)) + name: Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(db.Text, server_default=db.text("''::character varying")) - mode: Mapped[str] = mapped_column(db.String(255)) - icon_type: Mapped[Optional[str]] = mapped_column(db.String(255)) # image, emoji - icon = db.Column(db.String(255)) - icon_background: Mapped[Optional[str]] = mapped_column(db.String(255)) + mode: Mapped[str] = mapped_column(String(255)) + icon_type: Mapped[Optional[str]] = mapped_column(String(255)) # image, emoji + icon = db.Column(String(255)) + icon_background: Mapped[Optional[str]] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) workflow_id = mapped_column(StringUUID, nullable=True) - status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying")) + status: Mapped[str] = mapped_column(String(255), server_default=db.text("'normal'::character varying")) enable_site: Mapped[bool] = mapped_column(db.Boolean) enable_api: Mapped[bool] = mapped_column(db.Boolean) api_rpm: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) @@ -306,8 +306,8 @@ class AppModelConfig(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - provider = mapped_column(db.String(255), nullable=True) - model_id = mapped_column(db.String(255), nullable=True) + provider = mapped_column(String(255), nullable=True) + model_id = mapped_column(String(255), nullable=True) configs = mapped_column(db.JSON, nullable=True) created_by = mapped_column(StringUUID, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -321,12 +321,12 @@ class AppModelConfig(Base): more_like_this = mapped_column(db.Text) model = mapped_column(db.Text) user_input_form = mapped_column(db.Text) - dataset_query_variable = mapped_column(db.String(255)) + dataset_query_variable = mapped_column(String(255)) pre_prompt = mapped_column(db.Text) agent_mode = mapped_column(db.Text) sensitive_word_avoidance = mapped_column(db.Text) retriever_resource = mapped_column(db.Text) - prompt_type = mapped_column(db.String(255), nullable=False, server_default=db.text("'simple'::character varying")) + prompt_type = mapped_column(String(255), nullable=False, server_default=db.text("'simple'::character varying")) chat_prompt_config = mapped_column(db.Text) completion_prompt_config = mapped_column(db.Text) dataset_configs = mapped_column(db.Text) @@ -561,14 +561,14 @@ class RecommendedApp(Base): id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) description = mapped_column(db.JSON, nullable=False) - copyright = mapped_column(db.String(255), nullable=False) - privacy_policy = mapped_column(db.String(255), nullable=False) + copyright: Mapped[str] = mapped_column(String(255), nullable=False) + privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False) custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - category = mapped_column(db.String(255), nullable=False) - position = mapped_column(db.Integer, nullable=False, default=0) - is_listed = mapped_column(db.Boolean, nullable=False, default=True) - install_count = mapped_column(db.Integer, nullable=False, default=0) - language = mapped_column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) + category: Mapped[str] = mapped_column(String(255), nullable=False) + position: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) + is_listed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=True) + install_count: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) + language = mapped_column(String(255), nullable=False, server_default=db.text("'en-US'::character varying")) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -591,8 +591,8 @@ class InstalledApp(Base): tenant_id = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=False) app_owner_tenant_id = mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False, default=0) - is_pinned = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + position: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) + is_pinned: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) last_used_at = mapped_column(db.DateTime, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -617,26 +617,26 @@ class Conversation(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) app_model_config_id = mapped_column(StringUUID, nullable=True) - model_provider = mapped_column(db.String(255), nullable=True) + model_provider = mapped_column(String(255), nullable=True) override_model_configs = mapped_column(db.Text) - model_id = mapped_column(db.String(255), nullable=True) - mode: Mapped[str] = mapped_column(db.String(255)) - name = mapped_column(db.String(255), nullable=False) + model_id = mapped_column(String(255), nullable=True) + mode: Mapped[str] = mapped_column(String(255)) + name: Mapped[str] = mapped_column(String(255), nullable=False) summary = mapped_column(db.Text) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) introduction = mapped_column(db.Text) system_instruction = mapped_column(db.Text) - system_instruction_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - status = mapped_column(db.String(255), nullable=False) + system_instruction_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + status: Mapped[str] = mapped_column(String(255), nullable=False) # The `invoke_from` records how the conversation is created. # # Its value corresponds to the members of `InvokeFrom`. # (api/core/app/entities/app_invoke_entities.py) - invoke_from = mapped_column(db.String(255), nullable=True) + invoke_from = mapped_column(String(255), nullable=True) # ref: ConversationSource. - from_source = mapped_column(db.String(255), nullable=False) + from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id = mapped_column(StringUUID) from_account_id = mapped_column(StringUUID) read_at = mapped_column(db.DateTime) @@ -650,7 +650,7 @@ class Conversation(Base): "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" ) - is_deleted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + is_deleted: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) @property def inputs(self): @@ -894,8 +894,8 @@ class Message(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - model_provider = mapped_column(db.String(255), nullable=True) - model_id = mapped_column(db.String(255), nullable=True) + model_provider = mapped_column(String(255), nullable=True) + model_id = mapped_column(String(255), nullable=True) override_model_configs = mapped_column(db.Text) conversation_id = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) @@ -911,17 +911,17 @@ class Message(Base): parent_message_id = mapped_column(StringUUID, nullable=True) provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0")) total_price = mapped_column(db.Numeric(10, 7)) - currency = mapped_column(db.String(255), nullable=False) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + currency: Mapped[str] = mapped_column(String(255), nullable=False) + status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying")) error = mapped_column(db.Text) message_metadata = mapped_column(db.Text) - invoke_from: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - from_source = mapped_column(db.String(255), nullable=False) + invoke_from: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID) from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - agent_based = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + agent_based: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) @property @@ -1238,9 +1238,9 @@ class MessageFeedback(Base): app_id = mapped_column(StringUUID, nullable=False) conversation_id = mapped_column(StringUUID, nullable=False) message_id = mapped_column(StringUUID, nullable=False) - rating = mapped_column(db.String(255), nullable=False) + rating: Mapped[str] = mapped_column(String(255), nullable=False) content = mapped_column(db.Text) - from_source = mapped_column(db.String(255), nullable=False) + from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id = mapped_column(StringUUID) from_account_id = mapped_column(StringUUID) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1298,12 +1298,12 @@ class MessageFile(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(db.String(255), nullable=False) - transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False) + type: Mapped[str] = mapped_column(String(255), nullable=False) + transfer_method: Mapped[str] = mapped_column(String(255), nullable=False) url: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) - belongs_to: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) + belongs_to: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) - created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) + created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1323,7 +1323,7 @@ class MessageAnnotation(Base): message_id: Mapped[Optional[str]] = mapped_column(StringUUID) question = db.Column(db.Text, nullable=True) content = mapped_column(db.Text, nullable=False) - hit_count = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + hit_count: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) account_id = mapped_column(StringUUID, nullable=False) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1415,10 +1415,10 @@ class OperationLog(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) account_id = mapped_column(StringUUID, nullable=False) - action = mapped_column(db.String(255), nullable=False) + action: Mapped[str] = mapped_column(String(255), nullable=False) content = mapped_column(db.JSON) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - created_ip = mapped_column(db.String(255), nullable=False) + created_ip: Mapped[str] = mapped_column(String(255), nullable=False) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1433,10 +1433,10 @@ class EndUser(Base, UserMixin): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(db.String(255), nullable=False) - external_user_id = mapped_column(db.String(255), nullable=True) - name = mapped_column(db.String(255)) - is_anonymous = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + type: Mapped[str] = mapped_column(String(255), nullable=False) + external_user_id = mapped_column(String(255), nullable=True) + name = mapped_column(String(255)) + is_anonymous: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) session_id: Mapped[str] = mapped_column() created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1452,10 +1452,10 @@ class AppMCPServer(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=False) - name = mapped_column(db.String(255), nullable=False) - description = mapped_column(db.String(255), nullable=False) - server_code = mapped_column(db.String(255), nullable=False) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[str] = mapped_column(String(255), nullable=False) + server_code: Mapped[str] = mapped_column(String(255), nullable=False) + status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying")) parameters = mapped_column(db.Text, nullable=False) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1485,28 +1485,28 @@ class Site(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - title = mapped_column(db.String(255), nullable=False) - icon_type = mapped_column(db.String(255), nullable=True) - icon = mapped_column(db.String(255)) - icon_background = mapped_column(db.String(255)) + title: Mapped[str] = mapped_column(String(255), nullable=False) + icon_type = mapped_column(String(255), nullable=True) + icon = mapped_column(String(255)) + icon_background = mapped_column(String(255)) description = mapped_column(db.Text) - default_language = mapped_column(db.String(255), nullable=False) - chat_color_theme = mapped_column(db.String(255)) - chat_color_theme_inverted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - copyright = mapped_column(db.String(255)) - privacy_policy = mapped_column(db.String(255)) - show_workflow_steps = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - use_icon_as_answer_icon = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + default_language: Mapped[str] = mapped_column(String(255), nullable=False) + chat_color_theme = mapped_column(String(255)) + chat_color_theme_inverted: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + copyright = mapped_column(String(255)) + privacy_policy = mapped_column(String(255)) + show_workflow_steps: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + use_icon_as_answer_icon: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="") - customize_domain = mapped_column(db.String(255)) - customize_token_strategy = mapped_column(db.String(255), nullable=False) - prompt_public = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + customize_domain = mapped_column(String(255)) + customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False) + prompt_public: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying")) created_by = mapped_column(StringUUID, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - code = mapped_column(db.String(255)) + code = mapped_column(String(255)) @property def custom_disclaimer(self): @@ -1544,8 +1544,8 @@ class ApiToken(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=True) tenant_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(db.String(16), nullable=False) - token = mapped_column(db.String(255), nullable=False) + type = mapped_column(String(16), nullable=False) + token: Mapped[str] = mapped_column(String(255), nullable=False) last_used_at = mapped_column(db.DateTime, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1567,21 +1567,21 @@ class UploadFile(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - storage_type: Mapped[str] = mapped_column(db.String(255), nullable=False) - key: Mapped[str] = mapped_column(db.String(255), nullable=False) - name: Mapped[str] = mapped_column(db.String(255), nullable=False) + storage_type: Mapped[str] = mapped_column(String(255), nullable=False) + key: Mapped[str] = mapped_column(String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) size: Mapped[int] = mapped_column(db.Integer, nullable=False) - extension: Mapped[str] = mapped_column(db.String(255), nullable=False) - mime_type: Mapped[str] = mapped_column(db.String(255), nullable=True) + extension: Mapped[str] = mapped_column(String(255), nullable=False) + mime_type: Mapped[str] = mapped_column(String(255), nullable=True) created_by_role: Mapped[str] = mapped_column( - db.String(255), nullable=False, server_default=db.text("'account'::character varying") + String(255), nullable=False, server_default=db.text("'account'::character varying") ) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) used: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) used_at: Mapped[datetime | None] = mapped_column(db.DateTime, nullable=True) - hash: Mapped[str | None] = mapped_column(db.String(255), nullable=True) + hash: Mapped[str | None] = mapped_column(String(255), nullable=True) source_url: Mapped[str] = mapped_column(sa.TEXT, default="") def __init__( @@ -1630,10 +1630,10 @@ class ApiRequest(Base): id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) api_token_id = mapped_column(StringUUID, nullable=False) - path = mapped_column(db.String(255), nullable=False) + path: Mapped[str] = mapped_column(String(255), nullable=False) request = mapped_column(db.Text, nullable=True) response = mapped_column(db.Text, nullable=True) - ip = mapped_column(db.String(255), nullable=False) + ip: Mapped[str] = mapped_column(String(255), nullable=False) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1646,7 +1646,7 @@ class MessageChain(Base): id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = mapped_column(StringUUID, nullable=False) - type = mapped_column(db.String(255), nullable=False) + type: Mapped[str] = mapped_column(String(255), nullable=False) input = mapped_column(db.Text, nullable=True) output = mapped_column(db.Text, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -1663,7 +1663,7 @@ class MessageAgentThought(Base): id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = mapped_column(StringUUID, nullable=False) message_chain_id = mapped_column(StringUUID, nullable=True) - position = mapped_column(db.Integer, nullable=False) + position: Mapped[int] = mapped_column(db.Integer, nullable=False) thought = mapped_column(db.Text, nullable=True) tool = mapped_column(db.Text, nullable=True) tool_labels_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text")) @@ -1673,19 +1673,19 @@ class MessageAgentThought(Base): # plugin_id = mapped_column(StringUUID, nullable=True) ## for future design tool_process_data = mapped_column(db.Text, nullable=True) message = mapped_column(db.Text, nullable=True) - message_token = mapped_column(db.Integer, nullable=True) + message_token: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) message_unit_price = mapped_column(db.Numeric, nullable=True) message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) message_files = mapped_column(db.Text, nullable=True) answer = db.Column(db.Text, nullable=True) - answer_token = mapped_column(db.Integer, nullable=True) + answer_token: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) answer_unit_price = mapped_column(db.Numeric, nullable=True) answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - tokens = mapped_column(db.Integer, nullable=True) + tokens: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) total_price = mapped_column(db.Numeric, nullable=True) - currency = mapped_column(db.String, nullable=True) - latency = mapped_column(db.Float, nullable=True) - created_by_role = mapped_column(db.String, nullable=False) + currency = mapped_column(String, nullable=True) + latency: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True) + created_by_role = mapped_column(String, nullable=False) created_by = mapped_column(StringUUID, nullable=False) created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -1775,18 +1775,18 @@ class DatasetRetrieverResource(Base): id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False) + position: Mapped[int] = mapped_column(db.Integer, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) dataset_name = mapped_column(db.Text, nullable=False) document_id = mapped_column(StringUUID, nullable=True) document_name = mapped_column(db.Text, nullable=False) data_source_type = mapped_column(db.Text, nullable=True) segment_id = mapped_column(StringUUID, nullable=True) - score = mapped_column(db.Float, nullable=True) + score: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True) content = mapped_column(db.Text, nullable=False) - hit_count = mapped_column(db.Integer, nullable=True) - word_count = mapped_column(db.Integer, nullable=True) - segment_position = mapped_column(db.Integer, nullable=True) + hit_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) + word_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) + segment_position: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) index_node_hash = mapped_column(db.Text, nullable=True) retriever_from = mapped_column(db.Text, nullable=False) created_by = mapped_column(StringUUID, nullable=False) @@ -1805,8 +1805,8 @@ class Tag(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(db.String(16), nullable=False) - name = mapped_column(db.String(255), nullable=False) + type = mapped_column(String(16), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) created_by = mapped_column(StringUUID, nullable=False) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1836,13 +1836,13 @@ class TraceAppConfig(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - tracing_provider = mapped_column(db.String(255), nullable=True) + tracing_provider = mapped_column(String(255), nullable=True) tracing_config = mapped_column(db.JSON, nullable=True) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column( db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) - is_active = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + is_active: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) @property def tracing_config_dict(self): diff --git a/api/models/provider.py b/api/models/provider.py index 1e25f0c90..7bfc249b0 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -2,7 +2,7 @@ from datetime import datetime from enum import Enum from typing import Optional -from sqlalchemy import func, text +from sqlalchemy import DateTime, String, func, text from sqlalchemy.orm import Mapped, mapped_column from .base import Base @@ -56,22 +56,22 @@ class Provider(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) provider_type: Mapped[str] = mapped_column( - db.String(40), nullable=False, server_default=text("'custom'::character varying") + String(40), nullable=False, server_default=text("'custom'::character varying") ) encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) - last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) + last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) quota_type: Mapped[Optional[str]] = mapped_column( - db.String(40), nullable=True, server_default=text("''::character varying") + String(40), nullable=True, server_default=text("''::character varying") ) quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True) quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) def __repr__(self): return ( @@ -113,13 +113,13 @@ class ProviderModel(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class TenantDefaultModel(Base): @@ -131,11 +131,11 @@ class TenantDefaultModel(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class TenantPreferredModelProvider(Base): @@ -147,10 +147,10 @@ class TenantPreferredModelProvider(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderOrder(Base): @@ -162,22 +162,22 @@ class ProviderOrder(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False) - payment_id: Mapped[Optional[str]] = mapped_column(db.String(191)) - transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191)) + payment_product_id: Mapped[str] = mapped_column(String(191), nullable=False) + payment_id: Mapped[Optional[str]] = mapped_column(String(191)) + transaction_id: Mapped[Optional[str]] = mapped_column(String(191)) quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1")) - currency: Mapped[Optional[str]] = mapped_column(db.String(40)) + currency: Mapped[Optional[str]] = mapped_column(String(40)) total_amount: Mapped[Optional[int]] = mapped_column(db.Integer) payment_status: Mapped[str] = mapped_column( - db.String(40), nullable=False, server_default=text("'wait_pay'::character varying") + String(40), nullable=False, server_default=text("'wait_pay'::character varying") ) - paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) - pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) - refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + paid_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + pay_failed_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + refunded_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderModelSetting(Base): @@ -193,13 +193,13 @@ class ProviderModelSetting(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class LoadBalancingModelConfig(Base): @@ -215,11 +215,11 @@ class LoadBalancingModelConfig(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/source.py b/api/models/source.py index 100e0d96e..8191c874a 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,8 +1,10 @@ import json +from datetime import datetime +from typing import Optional -from sqlalchemy import func +from sqlalchemy import DateTime, String, func from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import Mapped, mapped_column from models.base import Base @@ -20,12 +22,12 @@ class DataSourceOauthBinding(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - access_token = mapped_column(db.String(255), nullable=False) - provider = mapped_column(db.String(255), nullable=False) + access_token: Mapped[str] = mapped_column(String(255), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) source_info = mapped_column(JSONB, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + disabled: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) class DataSourceApiKeyAuthBinding(Base): @@ -38,12 +40,12 @@ class DataSourceApiKeyAuthBinding(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - category = mapped_column(db.String(255), nullable=False) - provider = mapped_column(db.String(255), nullable=False) + category: Mapped[str] = mapped_column(String(255), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) credentials = mapped_column(db.Text, nullable=True) # JSON - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + disabled: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) def to_dict(self): return { diff --git a/api/models/task.py b/api/models/task.py index 3e5ebd209..66a47ea4d 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -2,6 +2,7 @@ from datetime import datetime from typing import Optional from celery import states # type: ignore +from sqlalchemy import DateTime, String from sqlalchemy.orm import Mapped, mapped_column from libs.datetime_utils import naive_utc_now @@ -16,22 +17,22 @@ class CeleryTask(Base): __tablename__ = "celery_taskmeta" id = mapped_column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) - task_id = mapped_column(db.String(155), unique=True) - status = mapped_column(db.String(50), default=states.PENDING) + task_id = mapped_column(String(155), unique=True) + status = mapped_column(String(50), default=states.PENDING) result = mapped_column(db.PickleType, nullable=True) date_done = mapped_column( - db.DateTime, + DateTime, default=lambda: naive_utc_now(), onupdate=lambda: naive_utc_now(), nullable=True, ) traceback = mapped_column(db.Text, nullable=True) - name = mapped_column(db.String(155), nullable=True) + name = mapped_column(String(155), nullable=True) args = mapped_column(db.LargeBinary, nullable=True) kwargs = mapped_column(db.LargeBinary, nullable=True) - worker = mapped_column(db.String(155), nullable=True) - retries = mapped_column(db.Integer, nullable=True) - queue = mapped_column(db.String(155), nullable=True) + worker = mapped_column(String(155), nullable=True) + retries: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) + queue = mapped_column(String(155), nullable=True) class CeleryTaskSet(Base): @@ -42,6 +43,6 @@ class CeleryTaskSet(Base): id: Mapped[int] = mapped_column( db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True ) - taskset_id = mapped_column(db.String(155), unique=True) + taskset_id = mapped_column(String(155), unique=True) result = mapped_column(db.PickleType, nullable=True) - date_done: Mapped[Optional[datetime]] = mapped_column(db.DateTime, default=lambda: naive_utc_now(), nullable=True) + date_done: Mapped[Optional[datetime]] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True) diff --git a/api/models/tools.py b/api/models/tools.py index 68f4211e5..1491cd90c 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -5,7 +5,7 @@ from urllib.parse import urlparse import sqlalchemy as sa from deprecated import deprecated -from sqlalchemy import ForeignKey, func +from sqlalchemy import ForeignKey, String, func from sqlalchemy.orm import Mapped, mapped_column from core.file import helpers as file_helpers @@ -30,8 +30,8 @@ class ToolOAuthSystemClient(Base): ) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) - provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + plugin_id = mapped_column(String(512), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) # oauth params of the tool provider encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) @@ -47,8 +47,8 @@ class ToolOAuthTenantClient(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) - provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + plugin_id: Mapped[str] = mapped_column(String(512), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) # oauth params of the tool provider encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) @@ -72,26 +72,26 @@ class BuiltinToolProvider(Base): # id of the tool provider id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) name: Mapped[str] = mapped_column( - db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying") + String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying") ) # id of the tenant tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # who created this tool provider user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # name of the tool provider - provider: Mapped[str] = mapped_column(db.String(256), nullable=False) + provider: Mapped[str] = mapped_column(String(256), nullable=False) # credential of the tool provider encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True) created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) # credential type, e.g., "api-key", "oauth2" credential_type: Mapped[str] = mapped_column( - db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") + String(32), nullable=False, server_default=db.text("'api-key'::character varying") ) expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1")) @@ -113,12 +113,12 @@ class ApiToolProvider(Base): id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the api provider - name = mapped_column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying")) + name = mapped_column(String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying")) # icon - icon = mapped_column(db.String(255), nullable=False) + icon: Mapped[str] = mapped_column(String(255), nullable=False) # original schema schema = mapped_column(db.Text, nullable=False) - schema_type_str: Mapped[str] = mapped_column(db.String(40), nullable=False) + schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False) # who created this tool user_id = mapped_column(StringUUID, nullable=False) # tenant id @@ -130,12 +130,12 @@ class ApiToolProvider(Base): # json format credentials credentials_str = mapped_column(db.Text, nullable=False) # privacy policy - privacy_policy = mapped_column(db.String(255), nullable=True) + privacy_policy = mapped_column(String(255), nullable=True) # custom_disclaimer custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def schema_type(self) -> ApiProviderSchemaType: @@ -173,11 +173,11 @@ class ToolLabelBinding(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # tool id - tool_id: Mapped[str] = mapped_column(db.String(64), nullable=False) + tool_id: Mapped[str] = mapped_column(String(64), nullable=False) # tool type - tool_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + tool_type: Mapped[str] = mapped_column(String(40), nullable=False) # label name - label_name: Mapped[str] = mapped_column(db.String(40), nullable=False) + label_name: Mapped[str] = mapped_column(String(40), nullable=False) class WorkflowToolProvider(Base): @@ -194,15 +194,15 @@ class WorkflowToolProvider(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the workflow provider - name: Mapped[str] = mapped_column(db.String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) # label of the workflow provider - label: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="") + label: Mapped[str] = mapped_column(String(255), nullable=False, server_default="") # icon - icon: Mapped[str] = mapped_column(db.String(255), nullable=False) + icon: Mapped[str] = mapped_column(String(255), nullable=False) # app id of the workflow provider app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # version of the workflow provider - version: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="") + version: Mapped[str] = mapped_column(String(255), nullable=False, server_default="") # who created this tool user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # tenant id @@ -212,13 +212,13 @@ class WorkflowToolProvider(Base): # parameter configuration parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default="[]") # privacy policy - privacy_policy: Mapped[str] = mapped_column(db.String(255), nullable=True, server_default="") + privacy_policy: Mapped[str] = mapped_column(String(255), nullable=True, server_default="") created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) @property @@ -253,15 +253,15 @@ class MCPToolProvider(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the mcp provider - name: Mapped[str] = mapped_column(db.String(40), nullable=False) + name: Mapped[str] = mapped_column(String(40), nullable=False) # server identifier of the mcp provider - server_identifier: Mapped[str] = mapped_column(db.String(64), nullable=False) + server_identifier: Mapped[str] = mapped_column(String(64), nullable=False) # encrypted url of the mcp provider server_url: Mapped[str] = mapped_column(db.Text, nullable=False) # hash of server_url for uniqueness check - server_url_hash: Mapped[str] = mapped_column(db.String(64), nullable=False) + server_url_hash: Mapped[str] = mapped_column(String(64), nullable=False) # icon of the mcp provider - icon: Mapped[str] = mapped_column(db.String(255), nullable=True) + icon: Mapped[str] = mapped_column(String(255), nullable=True) # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # who created this tool @@ -273,10 +273,10 @@ class MCPToolProvider(Base): # tools tools: Mapped[str] = mapped_column(db.Text, nullable=False, default="[]") created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) def load_user(self) -> Account | None: @@ -355,11 +355,11 @@ class ToolModelInvoke(Base): # tenant id tenant_id = mapped_column(StringUUID, nullable=False) # provider - provider = mapped_column(db.String(255), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) # type - tool_type = mapped_column(db.String(40), nullable=False) + tool_type = mapped_column(String(40), nullable=False) # tool name - tool_name = mapped_column(db.String(128), nullable=False) + tool_name = mapped_column(String(128), nullable=False) # invoke parameters model_parameters = mapped_column(db.Text, nullable=False) # prompt messages @@ -367,15 +367,15 @@ class ToolModelInvoke(Base): # invoke response model_response = mapped_column(db.Text, nullable=False) - prompt_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - answer_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + prompt_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + answer_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0")) total_price = mapped_column(db.Numeric(10, 7)) - currency = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + currency: Mapped[str] = mapped_column(String(255), nullable=False) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @deprecated @@ -402,8 +402,8 @@ class ToolConversationVariables(Base): # variables pool variables_str = mapped_column(db.Text, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def variables(self) -> Any: @@ -429,11 +429,11 @@ class ToolFile(Base): # conversation id conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # file key - file_key: Mapped[str] = mapped_column(db.String(255), nullable=False) + file_key: Mapped[str] = mapped_column(String(255), nullable=False) # mime type - mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False) + mimetype: Mapped[str] = mapped_column(String(255), nullable=False) # original url - original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True) + original_url: Mapped[str] = mapped_column(String(2048), nullable=True) # name name: Mapped[str] = mapped_column(default="") # size @@ -465,13 +465,13 @@ class DeprecatedPublishedAppTool(Base): # to describe this parameter to llm, we need this field query_description = mapped_column(db.Text, nullable=False) # query name, the name of the query parameter - query_name = mapped_column(db.String(40), nullable=False) + query_name = mapped_column(String(40), nullable=False) # name of the tool provider - tool_name = mapped_column(db.String(40), nullable=False) + tool_name = mapped_column(String(40), nullable=False) # author - author = mapped_column(db.String(40), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + author = mapped_column(String(40), nullable=False) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def description_i18n(self) -> I18nObject: diff --git a/api/models/web.py b/api/models/web.py index ce00f4010..1bf9b5c76 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,4 +1,6 @@ -from sqlalchemy import func +from datetime import datetime + +from sqlalchemy import DateTime, String, func from sqlalchemy.orm import Mapped, mapped_column from models.base import Base @@ -19,10 +21,10 @@ class SavedMessage(Base): app_id = mapped_column(StringUUID, nullable=False) message_id = mapped_column(StringUUID, nullable=False) created_by_role = mapped_column( - db.String(255), nullable=False, server_default=db.text("'end_user'::character varying") + String(255), nullable=False, server_default=db.text("'end_user'::character varying") ) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @property def message(self): @@ -40,7 +42,7 @@ class PinnedConversation(Base): app_id = mapped_column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID) created_by_role = mapped_column( - db.String(255), nullable=False, server_default=db.text("'end_user'::character varying") + String(255), nullable=False, server_default=db.text("'end_user'::character varying") ) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/workflow.py b/api/models/workflow.py index d89db6c7d..6c7d061bb 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union from uuid import uuid4 from flask_login import current_user -from sqlalchemy import orm +from sqlalchemy import DateTime, orm from core.file.constants import maybe_file_object from core.file.models import File @@ -25,7 +25,7 @@ if TYPE_CHECKING: from models.model import AppMode import sqlalchemy as sa -from sqlalchemy import Index, PrimaryKeyConstraint, UniqueConstraint, func +from sqlalchemy import Index, PrimaryKeyConstraint, String, UniqueConstraint, func from sqlalchemy.orm import Mapped, declared_attr, mapped_column from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE @@ -124,17 +124,17 @@ class Workflow(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(db.String(255), nullable=False) - version: Mapped[str] = mapped_column(db.String(255), nullable=False) + type: Mapped[str] = mapped_column(String(255), nullable=False) + version: Mapped[str] = mapped_column(String(255), nullable=False) marked_name: Mapped[str] = mapped_column(default="", server_default="") marked_comment: Mapped[str] = mapped_column(default="", server_default="") graph: Mapped[str] = mapped_column(sa.Text) _features: Mapped[str] = mapped_column("features", sa.TEXT) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, + DateTime, nullable=False, default=naive_utc_now(), server_onupdate=func.current_timestamp(), @@ -500,21 +500,21 @@ class WorkflowRun(Base): app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) - type: Mapped[str] = mapped_column(db.String(255)) - triggered_from: Mapped[str] = mapped_column(db.String(255)) - version: Mapped[str] = mapped_column(db.String(255)) + type: Mapped[str] = mapped_column(String(255)) + triggered_from: Mapped[str] = mapped_column(String(255)) + version: Mapped[str] = mapped_column(String(255)) graph: Mapped[Optional[str]] = mapped_column(db.Text) inputs: Mapped[Optional[str]] = mapped_column(db.Text) - status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded + status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") error: Mapped[Optional[str]] = mapped_column(db.Text) elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0")) total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True) - created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user + created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime) exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True) @property @@ -708,25 +708,25 @@ class WorkflowNodeExecutionModel(Base): tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) - triggered_from: Mapped[str] = mapped_column(db.String(255)) + triggered_from: Mapped[str] = mapped_column(String(255)) workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) index: Mapped[int] = mapped_column(db.Integer) - predecessor_node_id: Mapped[Optional[str]] = mapped_column(db.String(255)) - node_execution_id: Mapped[Optional[str]] = mapped_column(db.String(255)) - node_id: Mapped[str] = mapped_column(db.String(255)) - node_type: Mapped[str] = mapped_column(db.String(255)) - title: Mapped[str] = mapped_column(db.String(255)) + predecessor_node_id: Mapped[Optional[str]] = mapped_column(String(255)) + node_execution_id: Mapped[Optional[str]] = mapped_column(String(255)) + node_id: Mapped[str] = mapped_column(String(255)) + node_type: Mapped[str] = mapped_column(String(255)) + title: Mapped[str] = mapped_column(String(255)) inputs: Mapped[Optional[str]] = mapped_column(db.Text) process_data: Mapped[Optional[str]] = mapped_column(db.Text) outputs: Mapped[Optional[str]] = mapped_column(db.Text) - status: Mapped[str] = mapped_column(db.String(255)) + status: Mapped[str] = mapped_column(String(255)) error: Mapped[Optional[str]] = mapped_column(db.Text) elapsed_time: Mapped[float] = mapped_column(db.Float, server_default=db.text("0")) execution_metadata: Mapped[Optional[str]] = mapped_column(db.Text) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) - created_by_role: Mapped[str] = mapped_column(db.String(255)) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + created_by_role: Mapped[str] = mapped_column(String(255)) created_by: Mapped[str] = mapped_column(StringUUID) - finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime) @property def created_by_account(self): @@ -843,10 +843,10 @@ class WorkflowAppLog(Base): app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID) - created_from: Mapped[str] = mapped_column(db.String(255), nullable=False) - created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) + created_from: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @property def workflow_run(self): @@ -873,10 +873,10 @@ class ConversationVariable(Base): app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True) data: Mapped[str] = mapped_column(db.Text, nullable=False) created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True + DateTime, nullable=False, server_default=func.current_timestamp(), index=True ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None: @@ -936,14 +936,14 @@ class WorkflowDraftVariable(Base): id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) created_at: Mapped[datetime] = mapped_column( - db.DateTime, + DateTime, nullable=False, default=_naive_utc_datetime, server_default=func.current_timestamp(), ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, + DateTime, nullable=False, default=_naive_utc_datetime, server_default=func.current_timestamp(), @@ -958,7 +958,7 @@ class WorkflowDraftVariable(Base): # # If it's not edited after creation, its value is `None`. last_edited_at: Mapped[datetime | None] = mapped_column( - db.DateTime, + DateTime, nullable=True, default=None, ) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 128039999..da475a18f 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -2040,6 +2040,7 @@ class SegmentService: db.session.add(segment_document) # update document word count + assert document.word_count is not None document.word_count += segment_document.word_count db.session.add(document) db.session.commit() @@ -2124,6 +2125,7 @@ class SegmentService: else: keywords_list.append(None) # update document word count + assert document.word_count is not None document.word_count += increment_word_count db.session.add(document) try: @@ -2185,6 +2187,7 @@ class SegmentService: db.session.commit() # update document word count if word_count_change != 0: + assert document.word_count is not None document.word_count = max(0, document.word_count + word_count_change) db.session.add(document) # update segment index task @@ -2260,6 +2263,7 @@ class SegmentService: word_count_change = segment.word_count - word_count_change # update document word count if word_count_change != 0: + assert document.word_count is not None document.word_count = max(0, document.word_count + word_count_change) db.session.add(document) db.session.add(segment) @@ -2323,6 +2327,7 @@ class SegmentService: delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id) db.session.delete(segment) # update document word count + assert document.word_count is not None document.word_count -= segment.word_count db.session.add(document) db.session.commit() diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 714e30acc..dee43cd85 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -134,6 +134,7 @@ def batch_create_segment_to_index_task( db.session.add(segment_document) document_segments.append(segment_document) # update document word count + assert dataset_document.word_count is not None dataset_document.word_count += word_count_change db.session.add(dataset_document) # add index to db