feat: support tidb vector (#4588)
This commit is contained in:
@@ -112,6 +112,13 @@ PGVECTOR_USER=postgres
|
|||||||
PGVECTOR_PASSWORD=postgres
|
PGVECTOR_PASSWORD=postgres
|
||||||
PGVECTOR_DATABASE=postgres
|
PGVECTOR_DATABASE=postgres
|
||||||
|
|
||||||
|
# Tidb Vector configuration
|
||||||
|
TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com
|
||||||
|
TIDB_VECTOR_PORT=4000
|
||||||
|
TIDB_VECTOR_USER=xxx.root
|
||||||
|
TIDB_VECTOR_PASSWORD=xxxxxx
|
||||||
|
TIDB_VECTOR_DATABASE=dify
|
||||||
|
|
||||||
# Upload configuration
|
# Upload configuration
|
||||||
UPLOAD_FILE_SIZE_LIMIT=15
|
UPLOAD_FILE_SIZE_LIMIT=15
|
||||||
UPLOAD_FILE_BATCH_LIMIT=5
|
UPLOAD_FILE_BATCH_LIMIT=5
|
||||||
|
@@ -299,6 +299,13 @@ class Config:
|
|||||||
self.PGVECTOR_PASSWORD = get_env('PGVECTOR_PASSWORD')
|
self.PGVECTOR_PASSWORD = get_env('PGVECTOR_PASSWORD')
|
||||||
self.PGVECTOR_DATABASE = get_env('PGVECTOR_DATABASE')
|
self.PGVECTOR_DATABASE = get_env('PGVECTOR_DATABASE')
|
||||||
|
|
||||||
|
# tidb-vector settings
|
||||||
|
self.TIDB_VECTOR_HOST = get_env('TIDB_VECTOR_HOST')
|
||||||
|
self.TIDB_VECTOR_PORT = get_env('TIDB_VECTOR_PORT')
|
||||||
|
self.TIDB_VECTOR_USER = get_env('TIDB_VECTOR_USER')
|
||||||
|
self.TIDB_VECTOR_PASSWORD = get_env('TIDB_VECTOR_PASSWORD')
|
||||||
|
self.TIDB_VECTOR_DATABASE = get_env('TIDB_VECTOR_DATABASE')
|
||||||
|
|
||||||
# ------------------------
|
# ------------------------
|
||||||
# Mail Configurations.
|
# Mail Configurations.
|
||||||
# ------------------------
|
# ------------------------
|
||||||
|
@@ -476,7 +476,7 @@ class DatasetRetrievalSettingApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
vector_type = current_app.config['VECTOR_STORE']
|
vector_type = current_app.config['VECTOR_STORE']
|
||||||
if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs"}:
|
if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs", 'tidb_vector'}:
|
||||||
return {
|
return {
|
||||||
'retrieval_method': [
|
'retrieval_method': [
|
||||||
'semantic_search'
|
'semantic_search'
|
||||||
@@ -497,7 +497,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, vector_type):
|
def get(self, vector_type):
|
||||||
if vector_type in {'milvus', 'relyt', 'pgvector'}:
|
if vector_type in {'milvus', 'relyt', 'pgvector', 'tidb_vector'}:
|
||||||
return {
|
return {
|
||||||
'retrieval_method': [
|
'retrieval_method': [
|
||||||
'semantic_search'
|
'semantic_search'
|
||||||
|
0
api/core/rag/datasource/vdb/tidb_vector/__init__.py
Normal file
0
api/core/rag/datasource/vdb/tidb_vector/__init__.py
Normal file
214
api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
Normal file
214
api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
from pydantic import BaseModel, root_validator
|
||||||
|
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
|
||||||
|
from sqlalchemy import text as sql_text
|
||||||
|
from sqlalchemy.orm import Session, declarative_base
|
||||||
|
|
||||||
|
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||||
|
from core.rag.models.document import Document
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TiDBVectorConfig(BaseModel):
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
user: str
|
||||||
|
password: str
|
||||||
|
database: str
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_config(cls, values: dict) -> dict:
|
||||||
|
if not values['host']:
|
||||||
|
raise ValueError("config TIDB_VECTOR_HOST is required")
|
||||||
|
if not values['port']:
|
||||||
|
raise ValueError("config TIDB_VECTOR_PORT is required")
|
||||||
|
if not values['user']:
|
||||||
|
raise ValueError("config TIDB_VECTOR_USER is required")
|
||||||
|
if not values['password']:
|
||||||
|
raise ValueError("config TIDB_VECTOR_PASSWORD is required")
|
||||||
|
if not values['database']:
|
||||||
|
raise ValueError("config TIDB_VECTOR_DATABASE is required")
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
class TiDBVector(BaseVector):
|
||||||
|
|
||||||
|
def _table(self, dim: int) -> Table:
|
||||||
|
from tidb_vector.sqlalchemy import VectorType
|
||||||
|
return Table(
|
||||||
|
self._collection_name,
|
||||||
|
self._orm_base.metadata,
|
||||||
|
Column('id', String(36), primary_key=True, nullable=False),
|
||||||
|
Column("vector", VectorType(dim), nullable=False, comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})"),
|
||||||
|
Column("text", TEXT, nullable=False),
|
||||||
|
Column("meta", JSON, nullable=False),
|
||||||
|
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
|
||||||
|
Column("update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")),
|
||||||
|
extend_existing=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'):
|
||||||
|
super().__init__(collection_name)
|
||||||
|
self._client_config = config
|
||||||
|
self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?"
|
||||||
|
f"ssl_verify_cert=true&ssl_verify_identity=true")
|
||||||
|
self._distance_func = distance_func.lower()
|
||||||
|
self._engine = create_engine(self._url)
|
||||||
|
self._orm_base = declarative_base()
|
||||||
|
self._dimension = 1536
|
||||||
|
|
||||||
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
|
logger.info("create collection and add texts, collection_name: " + self._collection_name)
|
||||||
|
self._create_collection(len(embeddings[0]))
|
||||||
|
self.add_texts(texts, embeddings)
|
||||||
|
self._dimension = len(embeddings[0])
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _create_collection(self, dimension: int):
|
||||||
|
logger.info("_create_collection, collection_name " + self._collection_name)
|
||||||
|
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
|
||||||
|
with redis_client.lock(lock_name, timeout=20):
|
||||||
|
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
|
||||||
|
if redis_client.get(collection_exist_cache_key):
|
||||||
|
return
|
||||||
|
with Session(self._engine) as session:
|
||||||
|
session.begin()
|
||||||
|
drop_statement = sql_text(f"""DROP TABLE IF EXISTS {self._collection_name}; """)
|
||||||
|
session.execute(drop_statement)
|
||||||
|
create_statement = sql_text(f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS {self._collection_name} (
|
||||||
|
id CHAR(36) PRIMARY KEY,
|
||||||
|
text TEXT NOT NULL,
|
||||||
|
meta JSON NOT NULL,
|
||||||
|
vector VECTOR<FLOAT>({dimension}) NOT NULL COMMENT "hnsw(distance={self._distance_func})",
|
||||||
|
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
session.execute(create_statement)
|
||||||
|
# tidb vector not support 'CREATE/ADD INDEX' now
|
||||||
|
session.commit()
|
||||||
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||||
|
|
||||||
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
|
table = self._table(len(embeddings[0]))
|
||||||
|
ids = self._get_uuids(documents)
|
||||||
|
metas = [d.metadata for d in documents]
|
||||||
|
texts = [d.page_content for d in documents]
|
||||||
|
|
||||||
|
chunks_table_data = []
|
||||||
|
with self._engine.connect() as conn:
|
||||||
|
with conn.begin():
|
||||||
|
for id, text, meta, embedding in zip(
|
||||||
|
ids, texts, metas, embeddings
|
||||||
|
):
|
||||||
|
chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
|
||||||
|
|
||||||
|
# Execute the batch insert when the batch size is reached
|
||||||
|
if len(chunks_table_data) == 500:
|
||||||
|
conn.execute(insert(table).values(chunks_table_data))
|
||||||
|
# Clear the chunks_table_data list for the next batch
|
||||||
|
chunks_table_data.clear()
|
||||||
|
|
||||||
|
# Insert any remaining records that didn't make up a full batch
|
||||||
|
if chunks_table_data:
|
||||||
|
conn.execute(insert(table).values(chunks_table_data))
|
||||||
|
return ids
|
||||||
|
|
||||||
|
def text_exists(self, id: str) -> bool:
|
||||||
|
result = self.get_ids_by_metadata_field('doc_id', id)
|
||||||
|
return len(result) > 0
|
||||||
|
|
||||||
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
with Session(self._engine) as session:
|
||||||
|
ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
|
||||||
|
select_statement = sql_text(
|
||||||
|
f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str}); """
|
||||||
|
)
|
||||||
|
result = session.execute(select_statement).fetchall()
|
||||||
|
if result:
|
||||||
|
ids = [item[0] for item in result]
|
||||||
|
self._delete_by_ids(ids)
|
||||||
|
|
||||||
|
def _delete_by_ids(self, ids: list[str]) -> bool:
|
||||||
|
if ids is None:
|
||||||
|
raise ValueError("No ids provided to delete.")
|
||||||
|
table = self._table(self._dimension)
|
||||||
|
try:
|
||||||
|
with self._engine.connect() as conn:
|
||||||
|
with conn.begin():
|
||||||
|
delete_condition = table.c.id.in_(ids)
|
||||||
|
conn.execute(table.delete().where(delete_condition))
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print("Delete operation failed:", str(e))
|
||||||
|
return False
|
||||||
|
|
||||||
|
def delete_by_document_id(self, document_id: str):
|
||||||
|
ids = self.get_ids_by_metadata_field('document_id', document_id)
|
||||||
|
if ids:
|
||||||
|
self._delete_by_ids(ids)
|
||||||
|
|
||||||
|
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||||
|
with Session(self._engine) as session:
|
||||||
|
select_statement = sql_text(
|
||||||
|
f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.{key}' = '{value}'; """
|
||||||
|
)
|
||||||
|
result = session.execute(select_statement).fetchall()
|
||||||
|
if result:
|
||||||
|
return [item[0] for item in result]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||||
|
ids = self.get_ids_by_metadata_field(key, value)
|
||||||
|
if ids:
|
||||||
|
self._delete_by_ids(ids)
|
||||||
|
|
||||||
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||||
|
top_k = kwargs.get("top_k", 5)
|
||||||
|
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
||||||
|
filter = kwargs.get('filter')
|
||||||
|
distance = 1 - score_threshold
|
||||||
|
|
||||||
|
query_vector_str = ", ".join(format(x) for x in query_vector)
|
||||||
|
query_vector_str = "[" + query_vector_str + "]"
|
||||||
|
logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}")
|
||||||
|
|
||||||
|
docs = []
|
||||||
|
if self._distance_func == 'l2':
|
||||||
|
tidb_func = 'Vec_l2_distance'
|
||||||
|
elif self._distance_func == 'l2':
|
||||||
|
tidb_func = 'Vec_Cosine_distance'
|
||||||
|
else:
|
||||||
|
tidb_func = 'Vec_Cosine_distance'
|
||||||
|
|
||||||
|
with Session(self._engine) as session:
|
||||||
|
select_statement = sql_text(
|
||||||
|
f"""SELECT meta, text FROM (
|
||||||
|
SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
|
||||||
|
FROM {self._collection_name}
|
||||||
|
ORDER BY distance
|
||||||
|
LIMIT {top_k}
|
||||||
|
) t WHERE distance < {distance};"""
|
||||||
|
)
|
||||||
|
res = session.execute(select_statement)
|
||||||
|
results = [(row[0], row[1]) for row in res]
|
||||||
|
for meta, text in results:
|
||||||
|
docs.append(Document(page_content=text, metadata=json.loads(meta)))
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
|
# tidb doesn't support bm25 search
|
||||||
|
return []
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
with Session(self._engine) as session:
|
||||||
|
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
|
||||||
|
session.commit()
|
@@ -187,6 +187,31 @@ class Vector:
|
|||||||
database=config.get("PGVECTOR_DATABASE"),
|
database=config.get("PGVECTOR_DATABASE"),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
elif vector_type == "tidb_vector":
|
||||||
|
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
|
||||||
|
|
||||||
|
if self._dataset.index_struct_dict:
|
||||||
|
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
||||||
|
collection_name = class_prefix.lower()
|
||||||
|
else:
|
||||||
|
dataset_id = self._dataset.id
|
||||||
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||||
|
index_struct_dict = {
|
||||||
|
"type": 'tidb_vector',
|
||||||
|
"vector_store": {"class_prefix": collection_name}
|
||||||
|
}
|
||||||
|
self._dataset.index_struct = json.dumps(index_struct_dict)
|
||||||
|
|
||||||
|
return TiDBVector(
|
||||||
|
collection_name=collection_name,
|
||||||
|
config=TiDBVectorConfig(
|
||||||
|
host=config.get('TIDB_VECTOR_HOST'),
|
||||||
|
port=config.get('TIDB_VECTOR_PORT'),
|
||||||
|
user=config.get('TIDB_VECTOR_USER'),
|
||||||
|
password=config.get('TIDB_VECTOR_PASSWORD'),
|
||||||
|
database=config.get('TIDB_VECTOR_DATABASE'),
|
||||||
|
),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
|
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
|
||||||
|
|
||||||
|
@@ -81,5 +81,7 @@ pgvecto-rs==0.1.4
|
|||||||
firecrawl-py==0.0.5
|
firecrawl-py==0.0.5
|
||||||
oss2==2.18.5
|
oss2==2.18.5
|
||||||
pgvector==0.2.5
|
pgvector==0.2.5
|
||||||
|
pymysql==1.1.1
|
||||||
|
tidb-vector==0.0.9
|
||||||
google-cloud-aiplatform==1.49.0
|
google-cloud-aiplatform==1.49.0
|
||||||
vanna[postgres,mysql,clickhouse,duckdb]==0.5.5
|
vanna[postgres,mysql,clickhouse,duckdb]==0.5.5
|
||||||
|
@@ -0,0 +1,63 @@
|
|||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
|
||||||
|
from models.dataset import Document
|
||||||
|
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tidb_vector():
|
||||||
|
return TiDBVector(
|
||||||
|
collection_name='test_collection',
|
||||||
|
config=TiDBVectorConfig(
|
||||||
|
host="xxx.eu-central-1.xxx.aws.tidbcloud.com",
|
||||||
|
port="4000",
|
||||||
|
user="xxx.root",
|
||||||
|
password="xxxxxx",
|
||||||
|
database="dify"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TiDBVectorTest(AbstractVectorTest):
|
||||||
|
def __init__(self, vector):
|
||||||
|
super().__init__()
|
||||||
|
self.vector = vector
|
||||||
|
|
||||||
|
def text_exists(self):
|
||||||
|
exist = self.vector.text_exists(self.example_doc_id)
|
||||||
|
assert exist == False
|
||||||
|
|
||||||
|
def search_by_vector(self):
|
||||||
|
hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
|
||||||
|
assert len(hits_by_vector) == 0
|
||||||
|
|
||||||
|
def search_by_full_text(self):
|
||||||
|
hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
|
||||||
|
assert len(hits_by_full_text) == 0
|
||||||
|
|
||||||
|
def get_ids_by_metadata_field(self):
|
||||||
|
ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
|
||||||
|
assert len(ids) == 0
|
||||||
|
|
||||||
|
def delete_by_document_id(self):
|
||||||
|
self.vector.delete_by_document_id(document_id=self.example_doc_id)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tidb_vector(setup_mock_redis, setup_tidbvector_mock, tidb_vector, mock_session):
|
||||||
|
TiDBVectorTest(vector=tidb_vector).run_all_tests()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session():
|
||||||
|
with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.Session', new_callable=MagicMock) as mock_session:
|
||||||
|
yield mock_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def setup_tidbvector_mock(tidb_vector, mock_session):
|
||||||
|
with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine'):
|
||||||
|
with patch.object(tidb_vector._engine, 'connect'):
|
||||||
|
yield tidb_vector
|
@@ -134,6 +134,12 @@ services:
|
|||||||
PGVECTOR_USER: postgres
|
PGVECTOR_USER: postgres
|
||||||
PGVECTOR_PASSWORD: difyai123456
|
PGVECTOR_PASSWORD: difyai123456
|
||||||
PGVECTOR_DATABASE: dify
|
PGVECTOR_DATABASE: dify
|
||||||
|
# tidb vector configurations
|
||||||
|
TIDB_VECTOR_HOST: tidb
|
||||||
|
TIDB_VECTOR_PORT: 4000
|
||||||
|
TIDB_VECTOR_USER: xxx.root
|
||||||
|
TIDB_VECTOR_PASSWORD: xxxxxx
|
||||||
|
TIDB_VECTOR_DATABASE: dify
|
||||||
# Mail configuration, support: resend, smtp
|
# Mail configuration, support: resend, smtp
|
||||||
MAIL_TYPE: ''
|
MAIL_TYPE: ''
|
||||||
# default send from email address, if not specified
|
# default send from email address, if not specified
|
||||||
@@ -289,6 +295,12 @@ services:
|
|||||||
PGVECTOR_USER: postgres
|
PGVECTOR_USER: postgres
|
||||||
PGVECTOR_PASSWORD: difyai123456
|
PGVECTOR_PASSWORD: difyai123456
|
||||||
PGVECTOR_DATABASE: dify
|
PGVECTOR_DATABASE: dify
|
||||||
|
# tidb vector configurations
|
||||||
|
TIDB_VECTOR_HOST: tidb
|
||||||
|
TIDB_VECTOR_PORT: 4000
|
||||||
|
TIDB_VECTOR_USER: xxx.root
|
||||||
|
TIDB_VECTOR_PASSWORD: xxxxxx
|
||||||
|
TIDB_VECTOR_DATABASE: dify
|
||||||
# Notion import configuration, support public and internal
|
# Notion import configuration, support public and internal
|
||||||
NOTION_INTEGRATION_TYPE: public
|
NOTION_INTEGRATION_TYPE: public
|
||||||
NOTION_CLIENT_SECRET: you-client-secret
|
NOTION_CLIENT_SECRET: you-client-secret
|
||||||
|
Reference in New Issue
Block a user