optimize lindorm vdb add_texts (#17212)
Co-authored-by: jiangzhijie <jiangzhijie.jzj@alibaba-inc.com>
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from opensearchpy import OpenSearch
|
||||
from pydantic import BaseModel, model_validator
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
@@ -77,31 +79,74 @@ class LindormVectorStore(BaseVector):
|
||||
def refresh(self):
|
||||
self._client.indices.refresh(index=self._collection_name)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
actions = []
|
||||
def add_texts(
|
||||
self,
|
||||
documents: list[Document],
|
||||
embeddings: list[list[float]],
|
||||
batch_size: int = 64,
|
||||
timeout: int = 60,
|
||||
**kwargs,
|
||||
):
|
||||
logger.info(f"Total documents to add: {len(documents)}")
|
||||
uuids = self._get_uuids(documents)
|
||||
for i in range(len(documents)):
|
||||
action_header = {
|
||||
"index": {
|
||||
"_index": self.collection_name.lower(),
|
||||
"_id": uuids[i],
|
||||
|
||||
total_docs = len(documents)
|
||||
num_batches = (total_docs + batch_size - 1) // batch_size
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
)
|
||||
def _bulk_with_retry(actions):
|
||||
try:
|
||||
response = self._client.bulk(actions, timeout=timeout)
|
||||
if response["errors"]:
|
||||
error_items = [item for item in response["items"] if "error" in item["index"]]
|
||||
error_msg = f"Bulk indexing had {len(error_items)} errors"
|
||||
logger.exception(error_msg)
|
||||
raise Exception(error_msg)
|
||||
return response
|
||||
except Exception:
|
||||
logger.exception("Bulk indexing error")
|
||||
raise
|
||||
|
||||
for batch_num in range(num_batches):
|
||||
start_idx = batch_num * batch_size
|
||||
end_idx = min((batch_num + 1) * batch_size, total_docs)
|
||||
|
||||
actions = []
|
||||
for i in range(start_idx, end_idx):
|
||||
action_header = {
|
||||
"index": {
|
||||
"_index": self.collection_name.lower(),
|
||||
"_id": uuids[i],
|
||||
}
|
||||
}
|
||||
}
|
||||
action_values: dict[str, Any] = {
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
}
|
||||
if self._using_ugc:
|
||||
action_header["index"]["routing"] = self._routing
|
||||
if self._routing_field is not None:
|
||||
action_values[self._routing_field] = self._routing
|
||||
actions.append(action_header)
|
||||
actions.append(action_values)
|
||||
response = self._client.bulk(actions)
|
||||
if response["errors"]:
|
||||
for item in response["items"]:
|
||||
print(f"{item['index']['status']}: {item['index']['error']['type']}")
|
||||
action_values: dict[str, Any] = {
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i],
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
}
|
||||
if self._using_ugc:
|
||||
action_header["index"]["routing"] = self._routing
|
||||
if self._routing_field is not None:
|
||||
action_values[self._routing_field] = self._routing
|
||||
|
||||
actions.append(action_header)
|
||||
actions.append(action_values)
|
||||
|
||||
logger.info(f"Processing batch {batch_num + 1}/{num_batches} (documents {start_idx + 1} to {end_idx})")
|
||||
|
||||
try:
|
||||
_bulk_with_retry(actions)
|
||||
logger.info(f"Successfully processed batch {batch_num + 1}")
|
||||
# simple latency to avoid too many requests in a short time
|
||||
if batch_num < num_batches - 1:
|
||||
time.sleep(1)
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"Failed to process batch {batch_num + 1}")
|
||||
raise
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
query: dict[str, Any] = {
|
||||
@@ -130,7 +175,6 @@ class LindormVectorStore(BaseVector):
|
||||
if self._using_ugc:
|
||||
params["routing"] = self._routing
|
||||
self._client.delete(index=self._collection_name, id=id, params=params)
|
||||
self.refresh()
|
||||
else:
|
||||
logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
|
||||
|
||||
|
Reference in New Issue
Block a user