Add pgvecto_rs support and upgrade SQLAlchemy (#3833)

This commit is contained in:
Jyong
2024-04-29 11:58:17 +08:00
committed by GitHub
parent 975b2fb79e
commit 3e9dbe3e0a
26 changed files with 584 additions and 220 deletions

View File

@@ -4,10 +4,11 @@ import pickle
from json import JSONDecodeError
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.dialects.postgresql import JSONB
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import StringUUID
from models.account import Account
from models.model import App, Tag, TagBinding, UploadFile
@@ -22,8 +23,8 @@ class Dataset(db.Model):
INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None]
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=True)
provider = db.Column(db.String(255), nullable=False,
@@ -33,15 +34,15 @@ class Dataset(db.Model):
data_source_type = db.Column(db.String(255))
indexing_technique = db.Column(db.String(255), nullable=True)
index_struct = db.Column(db.Text, nullable=True)
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_by = db.Column(UUID, nullable=True)
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
embedding_model = db.Column(db.String(255), nullable=True)
embedding_model_provider = db.Column(db.String(255), nullable=True)
collection_binding_id = db.Column(UUID, nullable=True)
collection_binding_id = db.Column(StringUUID, nullable=True)
retrieval_model = db.Column(JSONB, nullable=True)
@property
@@ -145,13 +146,13 @@ class DatasetProcessRule(db.Model):
db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'),
)
id = db.Column(UUID, nullable=False,
id = db.Column(StringUUID, nullable=False,
server_default=db.text('uuid_generate_v4()'))
dataset_id = db.Column(UUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
mode = db.Column(db.String(255), nullable=False,
server_default=db.text("'automatic'::character varying"))
rules = db.Column(db.Text, nullable=True)
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -197,19 +198,19 @@ class Document(db.Model):
)
# initial fields
id = db.Column(UUID, nullable=False,
id = db.Column(StringUUID, nullable=False,
server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
dataset_id = db.Column(UUID, nullable=False)
tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
position = db.Column(db.Integer, nullable=False)
data_source_type = db.Column(db.String(255), nullable=False)
data_source_info = db.Column(db.Text, nullable=True)
dataset_process_rule_id = db.Column(UUID, nullable=True)
dataset_process_rule_id = db.Column(StringUUID, nullable=True)
batch = db.Column(db.String(255), nullable=False)
name = db.Column(db.String(255), nullable=False)
created_from = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False)
created_api_request_id = db.Column(UUID, nullable=True)
created_by = db.Column(StringUUID, nullable=False)
created_api_request_id = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -234,7 +235,7 @@ class Document(db.Model):
# pause
is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
paused_by = db.Column(UUID, nullable=True)
paused_by = db.Column(StringUUID, nullable=True)
paused_at = db.Column(db.DateTime, nullable=True)
# error
@@ -247,11 +248,11 @@ class Document(db.Model):
enabled = db.Column(db.Boolean, nullable=False,
server_default=db.text('true'))
disabled_at = db.Column(db.DateTime, nullable=True)
disabled_by = db.Column(UUID, nullable=True)
disabled_by = db.Column(StringUUID, nullable=True)
archived = db.Column(db.Boolean, nullable=False,
server_default=db.text('false'))
archived_reason = db.Column(db.String(255), nullable=True)
archived_by = db.Column(UUID, nullable=True)
archived_by = db.Column(StringUUID, nullable=True)
archived_at = db.Column(db.DateTime, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -356,11 +357,11 @@ class DocumentSegment(db.Model):
)
# initial fields
id = db.Column(UUID, nullable=False,
id = db.Column(StringUUID, nullable=False,
server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
dataset_id = db.Column(UUID, nullable=False)
document_id = db.Column(UUID, nullable=False)
tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
document_id = db.Column(StringUUID, nullable=False)
position = db.Column(db.Integer, nullable=False)
content = db.Column(db.Text, nullable=False)
answer = db.Column(db.Text, nullable=True)
@@ -377,13 +378,13 @@ class DocumentSegment(db.Model):
enabled = db.Column(db.Boolean, nullable=False,
server_default=db.text('true'))
disabled_at = db.Column(db.DateTime, nullable=True)
disabled_by = db.Column(UUID, nullable=True)
disabled_by = db.Column(StringUUID, nullable=True)
status = db.Column(db.String(255), nullable=False,
server_default=db.text("'waiting'::character varying"))
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_by = db.Column(UUID, nullable=True)
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
indexing_at = db.Column(db.DateTime, nullable=True)
@@ -421,9 +422,9 @@ class AppDatasetJoin(db.Model):
db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'),
)
id = db.Column(UUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False)
dataset_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
@property
@@ -438,13 +439,13 @@ class DatasetQuery(db.Model):
db.Index('dataset_query_dataset_id_idx', 'dataset_id'),
)
id = db.Column(UUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
dataset_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
dataset_id = db.Column(StringUUID, nullable=False)
content = db.Column(db.Text, nullable=False)
source = db.Column(db.String(255), nullable=False)
source_app_id = db.Column(UUID, nullable=True)
source_app_id = db.Column(StringUUID, nullable=True)
created_by_role = db.Column(db.String, nullable=False)
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
@@ -455,8 +456,8 @@ class DatasetKeywordTable(db.Model):
db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'),
)
id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
dataset_id = db.Column(UUID, nullable=False, unique=True)
id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
dataset_id = db.Column(StringUUID, nullable=False, unique=True)
keyword_table = db.Column(db.Text, nullable=False)
data_source_type = db.Column(db.String(255), nullable=False,
server_default=db.text("'database'::character varying"))
@@ -501,7 +502,7 @@ class Embedding(db.Model):
db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx')
)
id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
model_name = db.Column(db.String(40), nullable=False,
server_default=db.text("'text-embedding-ada-002'::character varying"))
hash = db.Column(db.String(64), nullable=False)
@@ -525,7 +526,7 @@ class DatasetCollectionBinding(db.Model):
)
id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
provider_name = db.Column(db.String(40), nullable=False)
model_name = db.Column(db.String(40), nullable=False)
type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)