feat: tencent vectordb: use grpc client and set upsert batch size (#16016)
Co-authored-by: wlleiiwang <wlleiiwang@tencent.com>
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
import json
|
||||
import math
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from tcvectordb import VectorDBClient # type: ignore
|
||||
from tcvectordb import RPCVectorDBClient, VectorDBException # type: ignore
|
||||
from tcvectordb.model import document, enum # type: ignore
|
||||
from tcvectordb.model import index as vdb_index # type: ignore
|
||||
from tcvectordb.model.document import Filter # type: ignore
|
||||
@@ -27,6 +28,7 @@ class TencentConfig(BaseModel):
|
||||
metric_type: str = "L2"
|
||||
shard: int = 1
|
||||
replicas: int = 2
|
||||
max_upsert_batch_size: int = 128
|
||||
|
||||
def to_tencent_params(self):
|
||||
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
|
||||
@@ -41,19 +43,10 @@ class TencentVector(BaseVector):
|
||||
def __init__(self, collection_name: str, config: TencentConfig):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = VectorDBClient(**self._client_config.to_tencent_params())
|
||||
self._db = self._init_database()
|
||||
self._client = RPCVectorDBClient(**self._client_config.to_tencent_params())
|
||||
|
||||
def _init_database(self):
|
||||
exists = False
|
||||
for db in self._client.list_databases():
|
||||
if db.database_name == self._client_config.database:
|
||||
exists = True
|
||||
break
|
||||
if exists:
|
||||
return self._client.database(self._client_config.database)
|
||||
else:
|
||||
return self._client.create_database(database_name=self._client_config.database)
|
||||
return self._client.create_database_if_not_exists(database_name=self._client_config.database)
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.TENCENT
|
||||
@@ -62,8 +55,11 @@ class TencentVector(BaseVector):
|
||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
||||
|
||||
def _has_collection(self) -> bool:
|
||||
collections = self._db.list_collections()
|
||||
return any(collection.collection_name == self._collection_name for collection in collections)
|
||||
return bool(
|
||||
self._client.exists_collection(
|
||||
database_name=self._client_config.database, collection_name=self.collection_name
|
||||
)
|
||||
)
|
||||
|
||||
def _create_collection(self, dimension: int) -> None:
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
@@ -75,7 +71,6 @@ class TencentVector(BaseVector):
|
||||
if self._has_collection():
|
||||
return
|
||||
|
||||
self.delete()
|
||||
index_type = None
|
||||
for k, v in enum.IndexType.__members__.items():
|
||||
if k == self._client_config.index_type:
|
||||
@@ -99,16 +94,41 @@ class TencentVector(BaseVector):
|
||||
params,
|
||||
),
|
||||
vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER),
|
||||
vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER),
|
||||
)
|
||||
|
||||
self._db.create_collection(
|
||||
name=self._collection_name,
|
||||
shard=self._client_config.shard,
|
||||
replicas=self._client_config.replicas,
|
||||
description="Collection for Dify",
|
||||
index=index,
|
||||
vdb_index.FilterIndex(self.field_metadata, enum.FieldType.Json, enum.IndexType.FILTER),
|
||||
)
|
||||
try:
|
||||
self._client.create_collection(
|
||||
database_name=self._client_config.database,
|
||||
collection_name=self._collection_name,
|
||||
shard=self._client_config.shard,
|
||||
replicas=self._client_config.replicas,
|
||||
description="Collection for Dify",
|
||||
index=index,
|
||||
)
|
||||
except VectorDBException as e:
|
||||
if "fieldType:json" not in e.message:
|
||||
raise e
|
||||
# This VectorDB version does not support the JSON field type; fall back to a String metadata index
|
||||
index = vdb_index.Index(
|
||||
vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
|
||||
vdb_index.VectorIndex(
|
||||
self.field_vector,
|
||||
dimension,
|
||||
index_type,
|
||||
metric_type,
|
||||
params,
|
||||
),
|
||||
vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER),
|
||||
vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER),
|
||||
)
|
||||
self._client.create_collection(
|
||||
database_name=self._client_config.database,
|
||||
collection_name=self._collection_name,
|
||||
shard=self._client_config.shard,
|
||||
replicas=self._client_config.replicas,
|
||||
description="Collection for Dify",
|
||||
index=index,
|
||||
)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
@@ -119,22 +139,34 @@ class TencentVector(BaseVector):
|
||||
texts = [doc.page_content for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents]
|
||||
total_count = len(embeddings)
|
||||
docs = []
|
||||
for i in range(0, total_count):
|
||||
if metadatas is None:
|
||||
continue
|
||||
metadata = metadatas[i] or {}
|
||||
doc = document.Document(
|
||||
id=metadata.get("doc_id"),
|
||||
vector=embeddings[i],
|
||||
text=texts[i],
|
||||
metadata=json.dumps(metadata),
|
||||
batch_size = self._client_config.max_upsert_batch_size
|
||||
batch = math.ceil(total_count / batch_size)
|
||||
for j in range(batch):
|
||||
docs = []
|
||||
start_idx = j * batch_size
|
||||
end_idx = min(total_count, (j + 1) * batch_size)
|
||||
for i in range(start_idx, end_idx):
|
||||
if metadatas is None:
|
||||
continue
|
||||
metadata = metadatas[i] or {}
|
||||
doc = document.Document(
|
||||
id=metadata.get("doc_id"),
|
||||
vector=embeddings[i],
|
||||
text=texts[i],
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
self._client.upsert(
|
||||
database_name=self._client_config.database,
|
||||
collection_name=self.collection_name,
|
||||
documents=docs,
|
||||
timeout=self._client_config.timeout,
|
||||
)
|
||||
docs.append(doc)
|
||||
self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
docs = self._db.collection(self._collection_name).query(document_ids=[id])
|
||||
docs = self._client.query(
|
||||
database_name=self._client_config.database, collection_name=self.collection_name, document_ids=[id]
|
||||
)
|
||||
if docs and len(docs) > 0:
|
||||
return True
|
||||
return False
|
||||
@@ -142,17 +174,25 @@ class TencentVector(BaseVector):
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
self._db.collection(self._collection_name).delete(document_ids=ids)
|
||||
self._client.delete(
|
||||
database_name=self._client_config.database, collection_name=self.collection_name, document_ids=ids
|
||||
)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(f"metadata.{key}", [value])))
|
||||
self._client.delete(
|
||||
database_name=self._client_config.database,
|
||||
collection_name=self.collection_name,
|
||||
filter=Filter(Filter.In(f"metadata.{key}", [value])),
|
||||
)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = None
|
||||
if document_ids_filter:
|
||||
filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
|
||||
res = self._db.collection(self._collection_name).search(
|
||||
res = self._client.search(
|
||||
database_name=self._client_config.database,
|
||||
collection_name=self.collection_name,
|
||||
vectors=[query_vector],
|
||||
filter=filter,
|
||||
params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
|
||||
@@ -173,8 +213,6 @@ class TencentVector(BaseVector):
|
||||
|
||||
for result in res[0]:
|
||||
meta = result.get(self.field_metadata)
|
||||
if meta is not None:
|
||||
meta = json.loads(meta)
|
||||
score = 1 - result.get("score", 0.0)
|
||||
if score > score_threshold:
|
||||
meta["score"] = score
|
||||
@@ -184,7 +222,7 @@ class TencentVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def delete(self) -> None:
|
||||
self._db.drop_collection(name=self._collection_name)
|
||||
self._client.drop_collection(database_name=self._client_config.database, collection_name=self.collection_name)
|
||||
|
||||
|
||||
class TencentVectorFactory(AbstractVectorFactory):
|
||||
|
Reference in New Issue
Block a user