fix: tablestore vdb support metadata filter (#22774)
Co-authored-by: xiaozhiqing.xzq <xiaozhiqing.xzq@alibaba-inc.com>
This commit is contained in:
@@ -118,10 +118,21 @@ class TableStoreVector(BaseVector):
|
|||||||
|
|
||||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||||
top_k = kwargs.get("top_k", 4)
|
top_k = kwargs.get("top_k", 4)
|
||||||
return self._search_by_vector(query_vector, top_k)
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
|
filtered_list = None
|
||||||
|
if document_ids_filter:
|
||||||
|
filtered_list = ["document_id=" + item for item in document_ids_filter]
|
||||||
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
|
return self._search_by_vector(query_vector, filtered_list, top_k, score_threshold)
|
||||||
|
|
||||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
return self._search_by_full_text(query)
|
top_k = kwargs.get("top_k", 4)
|
||||||
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
|
filtered_list = None
|
||||||
|
if document_ids_filter:
|
||||||
|
filtered_list = ["document_id=" + item for item in document_ids_filter]
|
||||||
|
|
||||||
|
return self._search_by_full_text(query, filtered_list, top_k)
|
||||||
|
|
||||||
def delete(self) -> None:
|
def delete(self) -> None:
|
||||||
self._delete_table_if_exist()
|
self._delete_table_if_exist()
|
||||||
@@ -230,32 +241,51 @@ class TableStoreVector(BaseVector):
|
|||||||
primary_key = [("id", id)]
|
primary_key = [("id", id)]
|
||||||
row = tablestore.Row(primary_key)
|
row = tablestore.Row(primary_key)
|
||||||
self._tablestore_client.delete_row(self._table_name, row, None)
|
self._tablestore_client.delete_row(self._table_name, row, None)
|
||||||
logging.info("Tablestore delete row successfully. id:%s", id)
|
|
||||||
|
|
||||||
def _search_by_metadata(self, key: str, value: str) -> list[str]:
|
def _search_by_metadata(self, key: str, value: str) -> list[str]:
|
||||||
query = tablestore.SearchQuery(
|
query = tablestore.SearchQuery(
|
||||||
tablestore.TermQuery(self._tags_field, str(key) + "=" + str(value)),
|
tablestore.TermQuery(self._tags_field, str(key) + "=" + str(value)),
|
||||||
limit=100,
|
limit=1000,
|
||||||
get_total_count=False,
|
get_total_count=False,
|
||||||
)
|
)
|
||||||
|
rows: list[str] = []
|
||||||
|
next_token = None
|
||||||
|
while True:
|
||||||
|
if next_token is not None:
|
||||||
|
query.next_token = next_token
|
||||||
|
|
||||||
search_response = self._tablestore_client.search(
|
search_response = self._tablestore_client.search(
|
||||||
table_name=self._table_name,
|
table_name=self._table_name,
|
||||||
index_name=self._index_name,
|
index_name=self._index_name,
|
||||||
search_query=query,
|
search_query=query,
|
||||||
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
|
columns_to_get=tablestore.ColumnsToGet(
|
||||||
)
|
column_names=[Field.PRIMARY_KEY.value], return_type=tablestore.ColumnReturnType.SPECIFIED
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
return [row[0][0][1] for row in search_response.rows]
|
if search_response is not None:
|
||||||
|
rows.extend([row[0][0][1] for row in search_response.rows])
|
||||||
|
|
||||||
def _search_by_vector(self, query_vector: list[float], top_k: int) -> list[Document]:
|
if search_response is None or search_response.next_token == b"":
|
||||||
ots_query = tablestore.KnnVectorQuery(
|
break
|
||||||
|
else:
|
||||||
|
next_token = search_response.next_token
|
||||||
|
|
||||||
|
return rows
|
||||||
|
|
||||||
|
def _search_by_vector(
|
||||||
|
self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float
|
||||||
|
) -> list[Document]:
|
||||||
|
knn_vector_query = tablestore.KnnVectorQuery(
|
||||||
field_name=Field.VECTOR.value,
|
field_name=Field.VECTOR.value,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
float32_query_vector=query_vector,
|
float32_query_vector=query_vector,
|
||||||
)
|
)
|
||||||
|
if document_ids_filter:
|
||||||
|
knn_vector_query.filter = tablestore.TermsQuery(self._tags_field, document_ids_filter)
|
||||||
|
|
||||||
sort = tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)])
|
sort = tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)])
|
||||||
search_query = tablestore.SearchQuery(ots_query, limit=top_k, get_total_count=False, sort=sort)
|
search_query = tablestore.SearchQuery(knn_vector_query, limit=top_k, get_total_count=False, sort=sort)
|
||||||
|
|
||||||
search_response = self._tablestore_client.search(
|
search_response = self._tablestore_client.search(
|
||||||
table_name=self._table_name,
|
table_name=self._table_name,
|
||||||
@@ -263,30 +293,32 @@ class TableStoreVector(BaseVector):
|
|||||||
search_query=search_query,
|
search_query=search_query,
|
||||||
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
|
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
|
||||||
)
|
)
|
||||||
logging.info(
|
|
||||||
"Tablestore search successfully. request_id:%s",
|
|
||||||
search_response.request_id,
|
|
||||||
)
|
|
||||||
return self._to_query_result(search_response)
|
|
||||||
|
|
||||||
def _to_query_result(self, search_response: tablestore.SearchResponse) -> list[Document]:
|
|
||||||
documents = []
|
documents = []
|
||||||
for row in search_response.rows:
|
for search_hit in search_response.search_hits:
|
||||||
documents.append(
|
if search_hit.score > score_threshold:
|
||||||
Document(
|
metadata = json.loads(search_hit.row[1][0][1])
|
||||||
page_content=row[1][2][1],
|
metadata["score"] = search_hit.score
|
||||||
vector=json.loads(row[1][3][1]),
|
documents.append(
|
||||||
metadata=json.loads(row[1][0][1]),
|
Document(
|
||||||
|
page_content=search_hit.row[1][2][1],
|
||||||
|
vector=json.loads(search_hit.row[1][3][1]),
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
def _search_by_full_text(self, query: str) -> list[Document]:
|
def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]:
|
||||||
|
bool_query = tablestore.BoolQuery()
|
||||||
|
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))
|
||||||
|
|
||||||
|
if document_ids_filter:
|
||||||
|
bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter))
|
||||||
|
|
||||||
search_query = tablestore.SearchQuery(
|
search_query = tablestore.SearchQuery(
|
||||||
query=tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value),
|
query=bool_query,
|
||||||
sort=tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)]),
|
sort=tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)]),
|
||||||
limit=100,
|
limit=top_k,
|
||||||
)
|
)
|
||||||
search_response = self._tablestore_client.search(
|
search_response = self._tablestore_client.search(
|
||||||
table_name=self._table_name,
|
table_name=self._table_name,
|
||||||
@@ -295,7 +327,16 @@ class TableStoreVector(BaseVector):
|
|||||||
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
|
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._to_query_result(search_response)
|
documents = []
|
||||||
|
for search_hit in search_response.search_hits:
|
||||||
|
documents.append(
|
||||||
|
Document(
|
||||||
|
page_content=search_hit.row[1][2][1],
|
||||||
|
vector=json.loads(search_hit.row[1][3][1]),
|
||||||
|
metadata=json.loads(search_hit.row[1][0][1]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return documents
|
||||||
|
|
||||||
|
|
||||||
class TableStoreVectorFactory(AbstractVectorFactory):
|
class TableStoreVectorFactory(AbstractVectorFactory):
|
||||||
|
@@ -1,4 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import tablestore
|
||||||
|
|
||||||
from core.rag.datasource.vdb.tablestore.tablestore_vector import (
|
from core.rag.datasource.vdb.tablestore.tablestore_vector import (
|
||||||
TableStoreConfig,
|
TableStoreConfig,
|
||||||
@@ -6,6 +9,8 @@ from core.rag.datasource.vdb.tablestore.tablestore_vector import (
|
|||||||
)
|
)
|
||||||
from tests.integration_tests.vdb.test_vector_store import (
|
from tests.integration_tests.vdb.test_vector_store import (
|
||||||
AbstractVectorTest,
|
AbstractVectorTest,
|
||||||
|
get_example_document,
|
||||||
|
get_example_text,
|
||||||
setup_mock_redis,
|
setup_mock_redis,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,6 +34,49 @@ class TableStoreVectorTest(AbstractVectorTest):
|
|||||||
assert len(ids) == 1
|
assert len(ids) == 1
|
||||||
assert ids[0] == self.example_doc_id
|
assert ids[0] == self.example_doc_id
|
||||||
|
|
||||||
|
def create_vector(self):
|
||||||
|
self.vector.create(
|
||||||
|
texts=[get_example_document(doc_id=self.example_doc_id)],
|
||||||
|
embeddings=[self.example_embedding],
|
||||||
|
)
|
||||||
|
while True:
|
||||||
|
search_response = self.vector._tablestore_client.search(
|
||||||
|
table_name=self.vector._table_name,
|
||||||
|
index_name=self.vector._index_name,
|
||||||
|
search_query=tablestore.SearchQuery(query=tablestore.MatchAllQuery(), get_total_count=True, limit=0),
|
||||||
|
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
|
||||||
|
)
|
||||||
|
if search_response.total_count == 1:
|
||||||
|
break
|
||||||
|
|
||||||
|
def search_by_vector(self):
|
||||||
|
super().search_by_vector()
|
||||||
|
docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[self.example_doc_id])
|
||||||
|
assert len(docs) == 1
|
||||||
|
assert docs[0].metadata["doc_id"] == self.example_doc_id
|
||||||
|
assert docs[0].metadata["score"] > 0
|
||||||
|
|
||||||
|
docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[str(uuid.uuid4())])
|
||||||
|
assert len(docs) == 0
|
||||||
|
|
||||||
|
def search_by_full_text(self):
|
||||||
|
super().search_by_full_text()
|
||||||
|
docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id])
|
||||||
|
assert len(docs) == 1
|
||||||
|
assert docs[0].metadata["doc_id"] == self.example_doc_id
|
||||||
|
assert not hasattr(docs[0], "score")
|
||||||
|
|
||||||
|
docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())])
|
||||||
|
assert len(docs) == 0
|
||||||
|
|
||||||
|
def run_all_tests(self):
|
||||||
|
try:
|
||||||
|
self.vector.delete()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return super().run_all_tests()
|
||||||
|
|
||||||
|
|
||||||
def test_tablestore_vector(setup_mock_redis):
|
def test_tablestore_vector(setup_mock_redis):
|
||||||
TableStoreVectorTest().run_all_tests()
|
TableStoreVectorTest().run_all_tests()
|
||||||
|
Reference in New Issue
Block a user