optimize lindorm vdb add_texts (#17212)

Co-authored-by: jiangzhijie <jiangzhijie.jzj@alibaba-inc.com>
Author: Jiang
Date: 2025-04-01 11:06:35 +08:00 (committed by GitHub)
Parent: ef1c1a12d2
Commit: ff388fe3e6

@@ -1,10 +1,12 @@
 import copy
 import json
 import logging
+import time
 from typing import Any, Optional
 
 from opensearchpy import OpenSearch
 from pydantic import BaseModel, model_validator
+from tenacity import retry, stop_after_attempt, wait_exponential
 
 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
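The tenacity imports added above drive the retry policy on the new bulk helper: up to three attempts, exponential backoff clamped between 4 and 10 seconds, and a RetryError once the attempts are exhausted. A minimal standalone sketch of that policy, assuming nothing beyond tenacity itself (flaky_call is a hypothetical stand-in for the bulk request, not part of this commit):

    from tenacity import RetryError, retry, stop_after_attempt, wait_exponential

    @retry(
        stop=stop_after_attempt(3),                          # give up after the third attempt
        wait=wait_exponential(multiplier=1, min=4, max=10),  # exponential backoff, clamped to 4-10s
    )
    def flaky_call():
        raise ConnectionError("transient failure")  # always fails, to exercise the retries

    try:
        flaky_call()
    except RetryError:
        # tenacity wraps the last exception in RetryError once all attempts fail
        pass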
@@ -77,31 +79,74 @@ class LindormVectorStore(BaseVector):
     def refresh(self):
         self._client.indices.refresh(index=self._collection_name)
 
-    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
-        actions = []
+    def add_texts(
+        self,
+        documents: list[Document],
+        embeddings: list[list[float]],
+        batch_size: int = 64,
+        timeout: int = 60,
+        **kwargs,
+    ):
+        logger.info(f"Total documents to add: {len(documents)}")
         uuids = self._get_uuids(documents)
-        for i in range(len(documents)):
-            action_header = {
-                "index": {
-                    "_index": self.collection_name.lower(),
-                    "_id": uuids[i],
+
+        total_docs = len(documents)
+        num_batches = (total_docs + batch_size - 1) // batch_size
+
+        @retry(
+            stop=stop_after_attempt(3),
+            wait=wait_exponential(multiplier=1, min=4, max=10),
+        )
+        def _bulk_with_retry(actions):
+            try:
+                response = self._client.bulk(actions, timeout=timeout)
+                if response["errors"]:
+                    error_items = [item for item in response["items"] if "error" in item["index"]]
+                    error_msg = f"Bulk indexing had {len(error_items)} errors"
+                    logger.exception(error_msg)
+                    raise Exception(error_msg)
+                return response
+            except Exception:
+                logger.exception("Bulk indexing error")
+                raise
+
+        for batch_num in range(num_batches):
+            start_idx = batch_num * batch_size
+            end_idx = min((batch_num + 1) * batch_size, total_docs)
+
+            actions = []
+            for i in range(start_idx, end_idx):
+                action_header = {
+                    "index": {
+                        "_index": self.collection_name.lower(),
+                        "_id": uuids[i],
+                    }
                 }
-            }
-            action_values: dict[str, Any] = {
-                Field.CONTENT_KEY.value: documents[i].page_content,
-                Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
-                Field.METADATA_KEY.value: documents[i].metadata,
-            }
-            if self._using_ugc:
-                action_header["index"]["routing"] = self._routing
-                if self._routing_field is not None:
-                    action_values[self._routing_field] = self._routing
-            actions.append(action_header)
-            actions.append(action_values)
-        response = self._client.bulk(actions)
-        if response["errors"]:
-            for item in response["items"]:
-                print(f"{item['index']['status']}: {item['index']['error']['type']}")
+                action_values: dict[str, Any] = {
+                    Field.CONTENT_KEY.value: documents[i].page_content,
+                    Field.VECTOR.value: embeddings[i],
+                    Field.METADATA_KEY.value: documents[i].metadata,
+                }
+                if self._using_ugc:
+                    action_header["index"]["routing"] = self._routing
+                    if self._routing_field is not None:
+                        action_values[self._routing_field] = self._routing
+                actions.append(action_header)
+                actions.append(action_values)
+
+            logger.info(f"Processing batch {batch_num + 1}/{num_batches} (documents {start_idx + 1} to {end_idx})")
+
+            try:
+                _bulk_with_retry(actions)
+                logger.info(f"Successfully processed batch {batch_num + 1}")
+                # simple latency to avoid too many requests in a short time
+                if batch_num < num_batches - 1:
+                    time.sleep(1)
+            except Exception:
+                logger.exception(f"Failed to process batch {batch_num + 1}")
+                raise
 
     def get_ids_by_metadata_field(self, key: str, value: str):
         query: dict[str, Any] = {
@@ -130,7 +175,6 @@ class LindormVectorStore(BaseVector):
                 if self._using_ugc:
                     params["routing"] = self._routing
                 self._client.delete(index=self._collection_name, id=id, params=params)
-                self.refresh()
             else:
                 logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")