feat: support tencent vector db (#3568)

This commit is contained in:
quicksand
2024-06-14 19:25:17 +08:00
committed by GitHub
parent 9ed21737d5
commit 4080f7b8ad
16 changed files with 481 additions and 5 deletions

View File

@@ -0,0 +1,227 @@
import json
from typing import Any, Optional
from flask import current_app
from pydantic import BaseModel
from tcvectordb import VectorDBClient
from tcvectordb.model import document, enum
from tcvectordb.model import index as vdb_index
from tcvectordb.model.document import Filter
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
class TencentConfig(BaseModel):
    """Connection and collection settings for a Tencent Cloud VectorDB instance."""
    url: str
    api_key: Optional[str]
    timeout: float = 30
    username: Optional[str]
    database: Optional[str]
    index_type: str = "HNSW"
    metric_type: str = "L2"
    # BUG FIX: the original declarations ended with trailing commas
    # ("shard: int = 1,"), which made the defaults the tuples (1,) / (2,)
    # instead of ints, breaking create_collection's shard/replicas params.
    shard: int = 1
    replicas: int = 2

    def to_tencent_params(self):
        """Return the keyword arguments expected by tcvectordb.VectorDBClient."""
        return {
            'url': self.url,
            'username': self.username,
            'key': self.api_key,
            'timeout': self.timeout
        }
class TencentVector(BaseVector):
    """BaseVector implementation backed by Tencent Cloud VectorDB (tcvectordb)."""

    # Field names used for the collection schema and for reading search hits.
    field_id: str = "id"
    field_vector: str = "vector"
    field_text: str = "text"
    field_metadata: str = "metadata"

    def __init__(self, collection_name: str, config: TencentConfig):
        super().__init__(collection_name)
        self._client_config = config
        self._client = VectorDBClient(**self._client_config.to_tencent_params())
        self._db = self._init_database()

    def _init_database(self):
        """Return the configured database handle, creating the database if absent."""
        for db in self._client.list_databases():
            if db.database_name == self._client_config.database:
                return self._client.database(self._client_config.database)
        return self._client.create_database(database_name=self._client_config.database)

    def get_type(self) -> str:
        # VectorType.TENCENT == 'tencent'; use the enum for consistency with
        # the factory and vector_type registry (VectorType is a str Enum).
        return VectorType.TENCENT

    def to_index_struct(self) -> dict:
        """Serialize the index identity stored on the dataset record."""
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self._collection_name}
        }

    def _has_collection(self) -> bool:
        """Whether a collection named self._collection_name already exists."""
        for collection in self._db.list_collections():
            if collection.collection_name == self._collection_name:
                return True
        return False

    def _create_collection(self, dimension: int) -> None:
        """Create the collection (idempotent), guarded by a redis lock + cache key.

        :param dimension: dimensionality of the stored vectors.
        :raises ValueError: if the configured index_type/metric_type is not
            supported by the tcvectordb SDK enums.
        """
        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
            if redis_client.get(collection_exist_cache_key):
                return
            if self._has_collection():
                return

            # Drop any stale collection before (re)creating it.
            self.delete()

            # Resolve configured names to SDK enum members.
            index_type = enum.IndexType.__members__.get(self._client_config.index_type)
            if index_type is None:
                raise ValueError("unsupported index_type")
            metric_type = enum.MetricType.__members__.get(self._client_config.metric_type)
            if metric_type is None:
                raise ValueError("unsupported metric_type")

            params = vdb_index.HNSWParams(m=16, efconstruction=200)
            index = vdb_index.Index(
                vdb_index.FilterIndex(
                    self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
                ),
                vdb_index.VectorIndex(
                    self.field_vector,
                    dimension,
                    index_type,
                    metric_type,
                    params,
                ),
                vdb_index.FilterIndex(
                    self.field_text, enum.FieldType.String, enum.IndexType.FILTER
                ),
                vdb_index.FilterIndex(
                    self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
                ),
            )
            self._db.create_collection(
                name=self._collection_name,
                shard=self._client_config.shard,
                replicas=self._client_config.replicas,
                description="Collection for Dify",
                index=index,
            )
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        """Create the collection sized to the embedding dimension, then insert."""
        self._create_collection(len(embeddings[0]))
        self.add_texts(texts, embeddings)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        """Upsert documents and their embeddings into the collection.

        Each document's metadata must contain a "doc_id" key, which becomes the
        primary key of the stored record.
        """
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        docs = []
        # Loop variable renamed from 'id' (shadowed the builtin); the original
        # per-iteration 'if metadatas is None' check was dead code, since
        # metadatas is always a list built above.
        for i in range(len(embeddings)):
            docs.append(document.Document(
                id=metadatas[i]["doc_id"],
                vector=embeddings[i],
                text=texts[i],
                metadata=json.dumps(metadatas[i]),
            ))
        self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout)

    def text_exists(self, id: str) -> bool:
        """Whether a record with the given primary key exists."""
        docs = self._db.collection(self._collection_name).query(document_ids=[id])
        return bool(docs)

    def delete_by_ids(self, ids: list[str]) -> None:
        self._db.collection(self._collection_name).delete(document_ids=ids)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """ANN search; honors optional 'ef', 'top_k' and 'score_threshold' kwargs."""
        res = self._db.collection(self._collection_name).search(
            vectors=[query_vector],
            params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
            retrieve_vector=False,
            limit=kwargs.get('top_k', 4),
            timeout=self._client_config.timeout,
        )
        # Falsy (None/0.0/absent) thresholds all collapse to 0.0, same as the
        # original double-get expression.
        score_threshold = kwargs.get("score_threshold") or 0.0
        return self._get_search_res(res, score_threshold)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # Full-text search is not supported by this backend.
        return []

    def _get_search_res(self, res, score_threshold):
        """Convert SDK search results into Documents, filtering by score.

        The stored "score" is assumed to be a distance, converted to a
        similarity via 1 - score before thresholding.
        """
        docs = []
        if res is None or len(res) == 0:
            return docs

        for result in res[0]:
            meta = result.get(self.field_metadata)
            # Guard against hits without metadata: the original would raise
            # TypeError on meta["score"] when meta was None.
            meta = json.loads(meta) if meta is not None else {}
            score = 1 - result.get("score", 0.0)
            if score > score_threshold:
                meta["score"] = score
                docs.append(Document(page_content=result.get(self.field_text), metadata=meta))
        return docs

    def delete(self) -> None:
        """Drop the whole collection."""
        self._db.drop_collection(name=self._collection_name)
class TencentVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector:
        """Build a TencentVector for the dataset.

        Reuses the collection name recorded in the dataset's index struct when
        present; otherwise generates one and persists the index struct.
        """
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
            collection_name = class_prefix.lower()
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
            # BUG FIX: the original recorded VectorType.TIDB_VECTOR here
            # (copy-paste error), mislabeling the index struct of a
            # Tencent-backed dataset.
            dataset.index_struct = json.dumps(
                self.gen_index_struct_dict(VectorType.TENCENT, collection_name))

        config = current_app.config
        return TencentVector(
            collection_name=collection_name,
            config=TencentConfig(
                url=config.get('TENCENT_VECTOR_DB_URL'),
                api_key=config.get('TENCENT_VECTOR_DB_API_KEY'),
                timeout=config.get('TENCENT_VECTOR_DB_TIMEOUT'),
                username=config.get('TENCENT_VECTOR_DB_USERNAME'),
                database=config.get('TENCENT_VECTOR_DB_DATABASE'),
                shard=config.get('TENCENT_VECTOR_DB_SHARD'),
                replicas=config.get('TENCENT_VECTOR_DB_REPLICAS'),
            )
        )

View File

@@ -39,7 +39,6 @@ class Vector:
def _init_vector(self) -> BaseVector:
config = current_app.config
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
@@ -76,6 +75,9 @@ class Vector:
case VectorType.WEAVIATE:
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
return WeaviateVectorFactory
case VectorType.TENCENT:
from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
return TencentVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@@ -10,3 +10,4 @@ class VectorType(str, Enum):
RELYT = 'relyt'
TIDB_VECTOR = 'tidb_vector'
WEAVIATE = 'weaviate'
TENCENT = 'tencent'