optimize lindorm vdb add_texts (#17212)

Co-authored-by: jiangzhijie <jiangzhijie.jzj@alibaba-inc.com>
This commit is contained in:
Jiang
2025-04-01 11:06:35 +08:00
committed by GitHub
parent ef1c1a12d2
commit ff388fe3e6

View File

@@ -1,10 +1,12 @@
import copy
import json
import logging
import time
from typing import Any, Optional
from opensearchpy import OpenSearch
from pydantic import BaseModel, model_validator
from tenacity import retry, stop_after_attempt, wait_exponential
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -77,31 +79,74 @@ class LindormVectorStore(BaseVector):
def refresh(self):
self._client.indices.refresh(index=self._collection_name)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
actions = []
def add_texts(
self,
documents: list[Document],
embeddings: list[list[float]],
batch_size: int = 64,
timeout: int = 60,
**kwargs,
):
logger.info(f"Total documents to add: {len(documents)}")
uuids = self._get_uuids(documents)
for i in range(len(documents)):
action_header = {
"index": {
"_index": self.collection_name.lower(),
"_id": uuids[i],
total_docs = len(documents)
num_batches = (total_docs + batch_size - 1) // batch_size
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
)
def _bulk_with_retry(actions):
try:
response = self._client.bulk(actions, timeout=timeout)
if response["errors"]:
error_items = [item for item in response["items"] if "error" in item["index"]]
error_msg = f"Bulk indexing had {len(error_items)} errors"
logger.exception(error_msg)
raise Exception(error_msg)
return response
except Exception:
logger.exception("Bulk indexing error")
raise
for batch_num in range(num_batches):
start_idx = batch_num * batch_size
end_idx = min((batch_num + 1) * batch_size, total_docs)
actions = []
for i in range(start_idx, end_idx):
action_header = {
"index": {
"_index": self.collection_name.lower(),
"_id": uuids[i],
}
}
}
action_values: dict[str, Any] = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
if self._routing_field is not None:
action_values[self._routing_field] = self._routing
actions.append(action_header)
actions.append(action_values)
response = self._client.bulk(actions)
if response["errors"]:
for item in response["items"]:
print(f"{item['index']['status']}: {item['index']['error']['type']}")
action_values: dict[str, Any] = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
if self._routing_field is not None:
action_values[self._routing_field] = self._routing
actions.append(action_header)
actions.append(action_values)
logger.info(f"Processing batch {batch_num + 1}/{num_batches} (documents {start_idx + 1} to {end_idx})")
try:
_bulk_with_retry(actions)
logger.info(f"Successfully processed batch {batch_num + 1}")
# simple latency to avoid too many requests in a short time
if batch_num < num_batches - 1:
time.sleep(1)
except Exception:
logger.exception(f"Failed to process batch {batch_num + 1}")
raise
def get_ids_by_metadata_field(self, key: str, value: str):
query: dict[str, Any] = {
@@ -130,7 +175,6 @@ class LindormVectorStore(BaseVector):
if self._using_ugc:
params["routing"] = self._routing
self._client.delete(index=self._collection_name, id=id, params=params)
self.refresh()
else:
logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")