feat: tencent vectordb: use grpc client and set upsert batch size (#16016)

Co-authored-by: wlleiiwang <wlleiiwang@tencent.com>
This commit is contained in:
wlleiiwang
2025-03-27 12:20:16 +08:00
committed by GitHub
parent c23135c9e8
commit a743d5dc71
4 changed files with 286 additions and 171 deletions

View File

@@ -1,8 +1,9 @@
import json
import math
from typing import Any, Optional
from pydantic import BaseModel
from tcvectordb import VectorDBClient # type: ignore
from tcvectordb import RPCVectorDBClient, VectorDBException # type: ignore
from tcvectordb.model import document, enum # type: ignore
from tcvectordb.model import index as vdb_index # type: ignore
from tcvectordb.model.document import Filter # type: ignore
@@ -27,6 +28,7 @@ class TencentConfig(BaseModel):
metric_type: str = "L2"
shard: int = 1
replicas: int = 2
max_upsert_batch_size: int = 128
def to_tencent_params(self):
    """Return the connection settings as a plain dict.

    ``url``, ``username`` and ``timeout`` keep their attribute names; the
    ``api_key`` attribute is exposed under the ``key`` entry.
    """
    params = {}
    params["url"] = self.url
    params["username"] = self.username
    params["key"] = self.api_key
    params["timeout"] = self.timeout
    return params
@@ -41,19 +43,10 @@ class TencentVector(BaseVector):
# Set up the store for one collection: remember the config and build the Tencent VectorDB
# client from the keyword arguments produced by config.to_tencent_params().
# NOTE(review): this listing is a commit diff rendered without +/- markers or indentation.
# The VectorDBClient and self._db lines appear to be the removed side and the
# RPCVectorDBClient line the added side — confirm against the checked-out file.
# NOTE(review): if the self._db assignment is indeed removed, nothing visible in this
# listing still calls _init_database() — TODO confirm the database is created elsewhere.
def __init__(self, collection_name: str, config: TencentConfig):
super().__init__(collection_name)
self._client_config = config
self._client = VectorDBClient(**self._client_config.to_tencent_params())
self._db = self._init_database()
self._client = RPCVectorDBClient(**self._client_config.to_tencent_params())
# Make sure the database named in the config exists and return what the client hands back.
# NOTE(review): this listing is a commit diff rendered without +/- markers or indentation.
# The list_databases() scan with the if/else looks like the pre-change body; the final
# create_database_if_not_exists(...) return looks like its one-line replacement — confirm
# against the checked-out file.
def _init_database(self):
exists = False
for db in self._client.list_databases():
if db.database_name == self._client_config.database:
exists = True
break
if exists:
return self._client.database(self._client_config.database)
else:
return self._client.create_database(database_name=self._client_config.database)
return self._client.create_database_if_not_exists(database_name=self._client_config.database)
def get_type(self) -> str:
    """Identify this vector store backend as ``VectorType.TENCENT``."""
    backend = VectorType.TENCENT
    return backend
@@ -62,8 +55,11 @@ class TencentVector(BaseVector):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
# Report whether this store's collection already exists in the configured database.
# NOTE(review): this listing is a commit diff rendered without +/- markers or indentation.
# The two self._db.list_collections() lines (name match over the listing) look like the
# removed side; the bool(self._client.exists_collection(...)) return looks like the added
# side — confirm against the checked-out file.
def _has_collection(self) -> bool:
collections = self._db.list_collections()
return any(collection.collection_name == self._collection_name for collection in collections)
return bool(
self._client.exists_collection(
database_name=self._client_config.database, collection_name=self.collection_name
)
)
def _create_collection(self, dimension: int) -> None:
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
@@ -75,7 +71,6 @@ class TencentVector(BaseVector):
if self._has_collection():
return
self.delete()
index_type = None
for k, v in enum.IndexType.__members__.items():
if k == self._client_config.index_type:
@@ -99,16 +94,41 @@ class TencentVector(BaseVector):
params,
),
vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER),
vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER),
)
self._db.create_collection(
name=self._collection_name,
shard=self._client_config.shard,
replicas=self._client_config.replicas,
description="Collection for Dify",
index=index,
vdb_index.FilterIndex(self.field_metadata, enum.FieldType.Json, enum.IndexType.FILTER),
)
try:
self._client.create_collection(
database_name=self._client_config.database,
collection_name=self._collection_name,
shard=self._client_config.shard,
replicas=self._client_config.replicas,
description="Collection for Dify",
index=index,
)
except VectorDBException as e:
if "fieldType:json" not in e.message:
raise e
# This VectorDB version does not support the JSON field type; fall back to String.
index = vdb_index.Index(
vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
vdb_index.VectorIndex(
self.field_vector,
dimension,
index_type,
metric_type,
params,
),
vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER),
vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER),
)
self._client.create_collection(
database_name=self._client_config.database,
collection_name=self._collection_name,
shard=self._client_config.shard,
replicas=self._client_config.replicas,
description="Collection for Dify",
index=index,
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
@@ -119,22 +139,34 @@ class TencentVector(BaseVector):
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
total_count = len(embeddings)
docs = []
for i in range(0, total_count):
if metadatas is None:
continue
metadata = metadatas[i] or {}
doc = document.Document(
id=metadata.get("doc_id"),
vector=embeddings[i],
text=texts[i],
metadata=json.dumps(metadata),
batch_size = self._client_config.max_upsert_batch_size
batch = math.ceil(total_count / batch_size)
for j in range(batch):
docs = []
start_idx = j * batch_size
end_idx = min(total_count, (j + 1) * batch_size)
for i in range(start_idx, end_idx):
if metadatas is None:
continue
metadata = metadatas[i] or {}
doc = document.Document(
id=metadata.get("doc_id"),
vector=embeddings[i],
text=texts[i],
metadata=metadata,
)
docs.append(doc)
self._client.upsert(
database_name=self._client_config.database,
collection_name=self.collection_name,
documents=docs,
timeout=self._client_config.timeout,
)
docs.append(doc)
self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout)
# Return True when a query for the given document id yields a non-empty result,
# False otherwise.
# NOTE(review): this listing is a commit diff rendered without +/- markers or indentation.
# The self._db.collection(...).query(...) line looks like the removed side and the
# self._client.query(...) call like the added side — confirm against the checked-out file.
def text_exists(self, id: str) -> bool:
docs = self._db.collection(self._collection_name).query(document_ids=[id])
docs = self._client.query(
database_name=self._client_config.database, collection_name=self.collection_name, document_ids=[id]
)
if docs and len(docs) > 0:
return True
return False
@@ -142,17 +174,25 @@ class TencentVector(BaseVector):
# Delete the documents with the given ids from the collection; an empty or missing
# id list returns immediately without calling the client.
# NOTE(review): this listing is a commit diff rendered without +/- markers or indentation.
# The self._db.collection(...).delete(...) line looks like the removed side and the
# self._client.delete(...) call like the added side — confirm against the checked-out file.
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._db.collection(self._collection_name).delete(document_ids=ids)
self._client.delete(
database_name=self._client_config.database, collection_name=self.collection_name, document_ids=ids
)
# Delete documents selected by a Filter.In condition on the field "metadata.<key>"
# with the single candidate value [value].
# NOTE(review): this listing is a commit diff rendered without +/- markers or indentation.
# The self._db.collection(...).delete(...) line looks like the removed side and the
# self._client.delete(...) call like the added side — confirm against the checked-out file.
def delete_by_metadata_field(self, key: str, value: str) -> None:
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(f"metadata.{key}", [value])))
self._client.delete(
database_name=self._client_config.database,
collection_name=self.collection_name,
filter=Filter(Filter.In(f"metadata.{key}", [value])),
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
document_ids_filter = kwargs.get("document_ids_filter")
filter = None
if document_ids_filter:
filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
res = self._db.collection(self._collection_name).search(
res = self._client.search(
database_name=self._client_config.database,
collection_name=self.collection_name,
vectors=[query_vector],
filter=filter,
params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
@@ -173,8 +213,6 @@ class TencentVector(BaseVector):
for result in res[0]:
meta = result.get(self.field_metadata)
if meta is not None:
meta = json.loads(meta)
score = 1 - result.get("score", 0.0)
if score > score_threshold:
meta["score"] = score
@@ -184,7 +222,7 @@ class TencentVector(BaseVector):
return docs
# Drop this store's whole collection.
# NOTE(review): this listing is a commit diff rendered without +/- markers or indentation.
# The self._db.drop_collection(...) line looks like the removed side and the
# self._client.drop_collection(...) line like the added side — confirm against the
# checked-out file.
def delete(self) -> None:
self._db.drop_collection(name=self._collection_name)
self._client.drop_collection(database_name=self._client_config.database, collection_name=self.collection_name)
class TencentVectorFactory(AbstractVectorFactory):