From 0797f9bc05126d6280ad38177777bac96ddbd05d Mon Sep 17 00:00:00 2001 From: Weaxs <459312872@qq.com> Date: Wed, 5 Jun 2024 18:19:53 +0800 Subject: [PATCH] feat: support tidb vector (#4588) --- api/.env.example | 7 + api/config.py | 7 + api/controllers/console/datasets/datasets.py | 4 +- .../datasource/vdb/tidb_vector/__init__.py | 0 .../datasource/vdb/tidb_vector/tidb_vector.py | 214 ++++++++++++++++++ api/core/rag/datasource/vdb/vector_factory.py | 25 ++ api/requirements.txt | 2 + .../vdb/tidb_vector/__init__.py | 0 .../vdb/tidb_vector/test_tidb_vector.py | 63 ++++++ docker/docker-compose.yaml | 12 + 10 files changed, 332 insertions(+), 2 deletions(-) create mode 100644 api/core/rag/datasource/vdb/tidb_vector/__init__.py create mode 100644 api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py create mode 100644 api/tests/integration_tests/vdb/tidb_vector/__init__.py create mode 100644 api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py diff --git a/api/.env.example b/api/.env.example index 183c22833..f112721a7 100644 --- a/api/.env.example +++ b/api/.env.example @@ -112,6 +112,13 @@ PGVECTOR_USER=postgres PGVECTOR_PASSWORD=postgres PGVECTOR_DATABASE=postgres +# Tidb Vector configuration +TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com +TIDB_VECTOR_PORT=4000 +TIDB_VECTOR_USER=xxx.root +TIDB_VECTOR_PASSWORD=xxxxxx +TIDB_VECTOR_DATABASE=dify + # Upload configuration UPLOAD_FILE_SIZE_LIMIT=15 UPLOAD_FILE_BATCH_LIMIT=5 diff --git a/api/config.py b/api/config.py index 269abf0a6..286b3336a 100644 --- a/api/config.py +++ b/api/config.py @@ -299,6 +299,13 @@ class Config: self.PGVECTOR_PASSWORD = get_env('PGVECTOR_PASSWORD') self.PGVECTOR_DATABASE = get_env('PGVECTOR_DATABASE') + # tidb-vector settings + self.TIDB_VECTOR_HOST = get_env('TIDB_VECTOR_HOST') + self.TIDB_VECTOR_PORT = get_env('TIDB_VECTOR_PORT') + self.TIDB_VECTOR_USER = get_env('TIDB_VECTOR_USER') + self.TIDB_VECTOR_PASSWORD = get_env('TIDB_VECTOR_PASSWORD') + self.TIDB_VECTOR_DATABASE = get_env('TIDB_VECTOR_DATABASE') + # ------------------------ # Mail Configurations. # ------------------------ diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 30dc6ac84..72c4c0905 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -476,7 +476,7 @@ class DatasetRetrievalSettingApi(Resource): @account_initialization_required def get(self): vector_type = current_app.config['VECTOR_STORE'] - if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs"}: + if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs", 'tidb_vector'}: return { 'retrieval_method': [ 'semantic_search' @@ -497,7 +497,7 @@ class DatasetRetrievalSettingMockApi(Resource): @login_required @account_initialization_required def get(self, vector_type): - if vector_type in {'milvus', 'relyt', 'pgvector'}: + if vector_type in {'milvus', 'relyt', 'pgvector', 'tidb_vector'}: return { 'retrieval_method': [ 'semantic_search' diff --git a/api/core/rag/datasource/vdb/tidb_vector/__init__.py b/api/core/rag/datasource/vdb/tidb_vector/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py new file mode 100644 index 000000000..b22839db4 --- /dev/null +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -0,0 +1,214 @@ +import json +import logging +from typing import Any + +import sqlalchemy +from pydantic import BaseModel, root_validator +from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert +from sqlalchemy import text as sql_text +from sqlalchemy.orm import Session, declarative_base + +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document +from extensions.ext_redis import redis_client + +logger = logging.getLogger(__name__) + + +class TiDBVectorConfig(BaseModel): + host: str + port: int + user: str + password: str + database: str + + @root_validator() + def validate_config(cls, values: dict) -> dict: + if not values['host']: + raise ValueError("config TIDB_VECTOR_HOST is required") + if not values['port']: + raise ValueError("config TIDB_VECTOR_PORT is required") + if not values['user']: + raise ValueError("config TIDB_VECTOR_USER is required") + if not values['password']: + raise ValueError("config TIDB_VECTOR_PASSWORD is required") + if not values['database']: + raise ValueError("config TIDB_VECTOR_DATABASE is required") + return values + + +class TiDBVector(BaseVector): + + def _table(self, dim: int) -> Table: + from tidb_vector.sqlalchemy import VectorType + return Table( + self._collection_name, + self._orm_base.metadata, + Column('id', String(36), primary_key=True, nullable=False), + Column("vector", VectorType(dim), nullable=False, comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})"), + Column("text", TEXT, nullable=False), + Column("meta", JSON, nullable=False), + Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")), + Column("update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")), + extend_existing=True + ) + + def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'): + super().__init__(collection_name) + self._client_config = config + self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?" + f"ssl_verify_cert=true&ssl_verify_identity=true") + self._distance_func = distance_func.lower() + self._engine = create_engine(self._url) + self._orm_base = declarative_base() + self._dimension = 1536 + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + logger.info("create collection and add texts, collection_name: " + self._collection_name) + self._create_collection(len(embeddings[0])) + self.add_texts(texts, embeddings) + self._dimension = len(embeddings[0]) + pass + + def _create_collection(self, dimension: int): + logger.info("_create_collection, collection_name " + self._collection_name) + lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + if redis_client.get(collection_exist_cache_key): + return + with Session(self._engine) as session: + session.begin() + drop_statement = sql_text(f"""DROP TABLE IF EXISTS {self._collection_name}; """) + session.execute(drop_statement) + create_statement = sql_text(f""" + CREATE TABLE IF NOT EXISTS {self._collection_name} ( + id CHAR(36) PRIMARY KEY, + text TEXT NOT NULL, + meta JSON NOT NULL, + vector VECTOR({dimension}) NOT NULL COMMENT "hnsw(distance={self._distance_func})", + create_time DATETIME DEFAULT CURRENT_TIMESTAMP, + update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP + ); + """) + session.execute(create_statement) + # tidb vector not support 'CREATE/ADD INDEX' now + session.commit() + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + table = self._table(len(embeddings[0])) + ids = self._get_uuids(documents) + metas = [d.metadata for d in documents] + texts = [d.page_content for d in documents] + + chunks_table_data = [] + with self._engine.connect() as conn: + with conn.begin(): + for id, text, meta, embedding in zip( + ids, texts, metas, embeddings + ): + chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta}) + + # Execute the batch insert when the batch size is reached + if len(chunks_table_data) == 500: + conn.execute(insert(table).values(chunks_table_data)) + # Clear the chunks_table_data list for the next batch + chunks_table_data.clear() + + # Insert any remaining records that didn't make up a full batch + if chunks_table_data: + conn.execute(insert(table).values(chunks_table_data)) + return ids + + def text_exists(self, id: str) -> bool: + result = self.get_ids_by_metadata_field('doc_id', id) + return len(result) > 0 + + def delete_by_ids(self, ids: list[str]) -> None: + with Session(self._engine) as session: + ids_str = ','.join(f"'{doc_id}'" for doc_id in ids) + select_statement = sql_text( + f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str}); """ + ) + result = session.execute(select_statement).fetchall() + if result: + ids = [item[0] for item in result] + self._delete_by_ids(ids) + + def _delete_by_ids(self, ids: list[str]) -> bool: + if ids is None: + raise ValueError("No ids provided to delete.") + table = self._table(self._dimension) + try: + with self._engine.connect() as conn: + with conn.begin(): + delete_condition = table.c.id.in_(ids) + conn.execute(table.delete().where(delete_condition)) + return True + except Exception as e: + print("Delete operation failed:", str(e)) + return False + + def delete_by_document_id(self, document_id: str): + ids = self.get_ids_by_metadata_field('document_id', document_id) + if ids: + self._delete_by_ids(ids) + + def get_ids_by_metadata_field(self, key: str, value: str): + with Session(self._engine) as session: + select_statement = sql_text( + f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.{key}' = '{value}'; """ + ) + result = session.execute(select_statement).fetchall() + if result: + return [item[0] for item in result] + else: + return None + + def delete_by_metadata_field(self, key: str, value: str) -> None: + ids = self.get_ids_by_metadata_field(key, value) + if ids: + self._delete_by_ids(ids) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + top_k = kwargs.get("top_k", 5) + score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + filter = kwargs.get('filter') + distance = 1 - score_threshold + + query_vector_str = ", ".join(format(x) for x in query_vector) + query_vector_str = "[" + query_vector_str + "]" + logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}") + + docs = [] + if self._distance_func == 'l2': + tidb_func = 'Vec_l2_distance' + elif self._distance_func == 'l2': + tidb_func = 'Vec_Cosine_distance' + else: + tidb_func = 'Vec_Cosine_distance' + + with Session(self._engine) as session: + select_statement = sql_text( + f"""SELECT meta, text FROM ( + SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance + FROM {self._collection_name} + ORDER BY distance + LIMIT {top_k} + ) t WHERE distance < {distance};""" + ) + res = session.execute(select_statement) + results = [(row[0], row[1]) for row in res] + for meta, text in results: + docs.append(Document(page_content=text, metadata=json.loads(meta))) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + # tidb doesn't support bm25 search + return [] + + def delete(self) -> None: + with Session(self._engine) as session: + session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};""")) + session.commit() diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 82ba6139e..b500b37d6 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -187,6 +187,31 @@ class Vector: database=config.get("PGVECTOR_DATABASE"), ), ) + elif vector_type == "tidb_vector": + from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig + + if self._dataset.index_struct_dict: + class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix.lower() + else: + dataset_id = self._dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() + index_struct_dict = { + "type": 'tidb_vector', + "vector_store": {"class_prefix": collection_name} + } + self._dataset.index_struct = json.dumps(index_struct_dict) + + return TiDBVector( + collection_name=collection_name, + config=TiDBVectorConfig( + host=config.get('TIDB_VECTOR_HOST'), + port=config.get('TIDB_VECTOR_PORT'), + user=config.get('TIDB_VECTOR_USER'), + password=config.get('TIDB_VECTOR_PASSWORD'), + database=config.get('TIDB_VECTOR_DATABASE'), + ), + ) else: raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") diff --git a/api/requirements.txt b/api/requirements.txt index cfd5fbbb4..1749b4a2d 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -81,5 +81,7 @@ pgvecto-rs==0.1.4 firecrawl-py==0.0.5 oss2==2.18.5 pgvector==0.2.5 +pymysql==1.1.1 +tidb-vector==0.0.9 google-cloud-aiplatform==1.49.0 vanna[postgres,mysql,clickhouse,duckdb]==0.5.5 diff --git a/api/tests/integration_tests/vdb/tidb_vector/__init__.py b/api/tests/integration_tests/vdb/tidb_vector/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py new file mode 100644 index 000000000..837a228a5 --- /dev/null +++ b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py @@ -0,0 +1,63 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig +from models.dataset import Document +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis + + +@pytest.fixture +def tidb_vector(): + return TiDBVector( + collection_name='test_collection', + config=TiDBVectorConfig( + host="xxx.eu-central-1.xxx.aws.tidbcloud.com", + port="4000", + user="xxx.root", + password="xxxxxx", + database="dify" + ) + ) + + +class TiDBVectorTest(AbstractVectorTest): + def __init__(self, vector): + super().__init__() + self.vector = vector + + def text_exists(self): + exist = self.vector.text_exists(self.example_doc_id) + assert exist == False + + def search_by_vector(self): + hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding) + assert len(hits_by_vector) == 0 + + def search_by_full_text(self): + hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + assert len(ids) == 0 + + def delete_by_document_id(self): + self.vector.delete_by_document_id(document_id=self.example_doc_id) + + +def test_tidb_vector(setup_mock_redis, setup_tidbvector_mock, tidb_vector, mock_session): + TiDBVectorTest(vector=tidb_vector).run_all_tests() + + +@pytest.fixture +def mock_session(): + with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.Session', new_callable=MagicMock) as mock_session: + yield mock_session + + +@pytest.fixture +def setup_tidbvector_mock(tidb_vector, mock_session): + with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine'): + with patch.object(tidb_vector._engine, 'connect'): + yield tidb_vector diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 35fa32738..c3b543051 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -134,6 +134,12 @@ services: PGVECTOR_USER: postgres PGVECTOR_PASSWORD: difyai123456 PGVECTOR_DATABASE: dify + # tidb vector configurations + TIDB_VECTOR_HOST: tidb + TIDB_VECTOR_PORT: 4000 + TIDB_VECTOR_USER: xxx.root + TIDB_VECTOR_PASSWORD: xxxxxx + TIDB_VECTOR_DATABASE: dify # Mail configuration, support: resend, smtp MAIL_TYPE: '' # default send from email address, if not specified @@ -289,6 +295,12 @@ services: PGVECTOR_USER: postgres PGVECTOR_PASSWORD: difyai123456 PGVECTOR_DATABASE: dify + # tidb vector configurations + TIDB_VECTOR_HOST: tidb + TIDB_VECTOR_PORT: 4000 + TIDB_VECTOR_USER: xxx.root + TIDB_VECTOR_PASSWORD: xxxxxx + TIDB_VECTOR_DATABASE: dify # Notion import configuration, support public and internal NOTION_INTEGRATION_TYPE: public NOTION_CLIENT_SECRET: you-client-secret