diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index 552068c99..55326fd60 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -118,10 +118,21 @@ class TableStoreVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) - return self._search_by_vector(query_vector, top_k) + document_ids_filter = kwargs.get("document_ids_filter") + filtered_list = None + if document_ids_filter: + filtered_list = ["document_id=" + item for item in document_ids_filter] + score_threshold = float(kwargs.get("score_threshold") or 0.0) + return self._search_by_vector(query_vector, filtered_list, top_k, score_threshold) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - return self._search_by_full_text(query) + top_k = kwargs.get("top_k", 4) + document_ids_filter = kwargs.get("document_ids_filter") + filtered_list = None + if document_ids_filter: + filtered_list = ["document_id=" + item for item in document_ids_filter] + + return self._search_by_full_text(query, filtered_list, top_k) def delete(self) -> None: self._delete_table_if_exist() @@ -230,32 +241,51 @@ class TableStoreVector(BaseVector): primary_key = [("id", id)] row = tablestore.Row(primary_key) self._tablestore_client.delete_row(self._table_name, row, None) - logging.info("Tablestore delete row successfully. id:%s", id) def _search_by_metadata(self, key: str, value: str) -> list[str]: query = tablestore.SearchQuery( tablestore.TermQuery(self._tags_field, str(key) + "=" + str(value)), - limit=100, + limit=1000, get_total_count=False, ) + rows: list[str] = [] + next_token = None + while True: + if next_token is not None: + query.next_token = next_token - search_response = self._tablestore_client.search( - table_name=self._table_name, - index_name=self._index_name, - search_query=query, - columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX), - ) + search_response = self._tablestore_client.search( + table_name=self._table_name, + index_name=self._index_name, + search_query=query, + columns_to_get=tablestore.ColumnsToGet( + column_names=[Field.PRIMARY_KEY.value], return_type=tablestore.ColumnReturnType.SPECIFIED + ), + ) - return [row[0][0][1] for row in search_response.rows] + if search_response is not None: + rows.extend([row[0][0][1] for row in search_response.rows]) - def _search_by_vector(self, query_vector: list[float], top_k: int) -> list[Document]: - ots_query = tablestore.KnnVectorQuery( + if search_response is None or search_response.next_token == b"": + break + else: + next_token = search_response.next_token + + return rows + + def _search_by_vector( + self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float + ) -> list[Document]: + knn_vector_query = tablestore.KnnVectorQuery( field_name=Field.VECTOR.value, top_k=top_k, float32_query_vector=query_vector, ) + if document_ids_filter: + knn_vector_query.filter = tablestore.TermsQuery(self._tags_field, document_ids_filter) + sort = tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)]) - search_query = tablestore.SearchQuery(ots_query, limit=top_k, get_total_count=False, sort=sort) + search_query = tablestore.SearchQuery(knn_vector_query, limit=top_k, get_total_count=False, sort=sort) search_response = self._tablestore_client.search( table_name=self._table_name, @@ -263,30 +293,32 @@ class TableStoreVector(BaseVector): search_query=search_query, columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX), ) - logging.info( - "Tablestore search successfully. request_id:%s", - search_response.request_id, - ) - return self._to_query_result(search_response) - - def _to_query_result(self, search_response: tablestore.SearchResponse) -> list[Document]: documents = [] - for row in search_response.rows: - documents.append( - Document( - page_content=row[1][2][1], - vector=json.loads(row[1][3][1]), - metadata=json.loads(row[1][0][1]), + for search_hit in search_response.search_hits: + if search_hit.score > score_threshold: + metadata = json.loads(search_hit.row[1][0][1]) + metadata["score"] = search_hit.score + documents.append( + Document( + page_content=search_hit.row[1][2][1], + vector=json.loads(search_hit.row[1][3][1]), + metadata=metadata, + ) ) - ) - + documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents - def _search_by_full_text(self, query: str) -> list[Document]: + def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]: + bool_query = tablestore.BoolQuery() + bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value)) + + if document_ids_filter: + bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter)) + search_query = tablestore.SearchQuery( - query=tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value), + query=bool_query, sort=tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)]), - limit=100, + limit=top_k, ) search_response = self._tablestore_client.search( table_name=self._table_name, @@ -295,7 +327,16 @@ class TableStoreVector(BaseVector): columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX), ) - return self._to_query_result(search_response) + documents = [] + for search_hit in search_response.search_hits: + documents.append( + Document( + page_content=search_hit.row[1][2][1], + vector=json.loads(search_hit.row[1][3][1]), + metadata=json.loads(search_hit.row[1][0][1]), + ) + ) + return documents class TableStoreVectorFactory(AbstractVectorFactory): diff --git a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py b/api/tests/integration_tests/vdb/tablestore/test_tablestore.py index da890d0b7..da549af1b 100644 --- a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py +++ b/api/tests/integration_tests/vdb/tablestore/test_tablestore.py @@ -1,4 +1,7 @@ import os +import uuid + +import tablestore from core.rag.datasource.vdb.tablestore.tablestore_vector import ( TableStoreConfig, @@ -6,6 +9,8 @@ from core.rag.datasource.vdb.tablestore.tablestore_vector import ( ) from tests.integration_tests.vdb.test_vector_store import ( AbstractVectorTest, + get_example_document, + get_example_text, setup_mock_redis, ) @@ -29,6 +34,49 @@ class TableStoreVectorTest(AbstractVectorTest): assert len(ids) == 1 assert ids[0] == self.example_doc_id + def create_vector(self): + self.vector.create( + texts=[get_example_document(doc_id=self.example_doc_id)], + embeddings=[self.example_embedding], + ) + while True: + search_response = self.vector._tablestore_client.search( + table_name=self.vector._table_name, + index_name=self.vector._index_name, + search_query=tablestore.SearchQuery(query=tablestore.MatchAllQuery(), get_total_count=True, limit=0), + columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX), + ) + if search_response.total_count == 1: + break + + def search_by_vector(self): + super().search_by_vector() + docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[self.example_doc_id]) + assert len(docs) == 1 + assert docs[0].metadata["doc_id"] == self.example_doc_id + assert docs[0].metadata["score"] > 0 + + docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[str(uuid.uuid4())]) + assert len(docs) == 0 + + def search_by_full_text(self): + super().search_by_full_text() + docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id]) + assert len(docs) == 1 + assert docs[0].metadata["doc_id"] == self.example_doc_id + assert not hasattr(docs[0], "score") + + docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())]) + assert len(docs) == 0 + + def run_all_tests(self): + try: + self.vector.delete() + except Exception: + pass + + return super().run_all_tests() + def test_tablestore_vector(setup_mock_redis): TableStoreVectorTest().run_all_tests()