diff --git a/api/.env.example b/api/.env.example index e7e704e13..ba76274c3 100644 --- a/api/.env.example +++ b/api/.env.example @@ -189,6 +189,7 @@ TENCENT_VECTOR_DB_USERNAME=dify TENCENT_VECTOR_DB_DATABASE=dify TENCENT_VECTOR_DB_SHARD=1 TENCENT_VECTOR_DB_REPLICAS=2 +TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH=false # ElasticSearch configuration ELASTICSEARCH_HOST=127.0.0.1 diff --git a/api/configs/middleware/vdb/tencent_vector_config.py b/api/configs/middleware/vdb/tencent_vector_config.py index 9cf4d07f6..a51823c3f 100644 --- a/api/configs/middleware/vdb/tencent_vector_config.py +++ b/api/configs/middleware/vdb/tencent_vector_config.py @@ -48,3 +48,8 @@ class TencentVectorDBConfig(BaseSettings): description="Name of the specific Tencent Vector Database to connect to", default=None, ) + + TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH: bool = Field( + description="Enable hybrid search features", + default=False, + ) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 396dae7a5..4644ac629 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -641,7 +641,6 @@ class DatasetRetrievalSettingApi(Resource): VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA - | VectorType.TENCENT | VectorType.PGVECTO_RS | VectorType.BAIDU | VectorType.VIKINGDB @@ -665,6 +664,7 @@ class DatasetRetrievalSettingApi(Resource): | VectorType.OPENGAUSS | VectorType.OCEANBASE | VectorType.TABLESTORE + | VectorType.TENCENT ): return { "retrieval_method": [ @@ -688,7 +688,6 @@ class DatasetRetrievalSettingMockApi(Resource): | VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA - | VectorType.TENCENT | VectorType.PGVECTO_RS | VectorType.BAIDU | VectorType.VIKINGDB @@ -710,6 +709,7 @@ class DatasetRetrievalSettingMockApi(Resource): | VectorType.OPENGAUSS | VectorType.OCEANBASE | VectorType.TABLESTORE + | VectorType.TENCENT ): return { "retrieval_method": [ diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 739e9e10a..540d71bb8 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -1,12 +1,14 @@ import json +import logging import math from typing import Any, Optional from pydantic import BaseModel +from tcvdb_text.encoder import BM25Encoder # type: ignore from tcvectordb import RPCVectorDBClient, VectorDBException # type: ignore from tcvectordb.model import document, enum # type: ignore from tcvectordb.model import index as vdb_index # type: ignore -from tcvectordb.model.document import Filter # type: ignore +from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, WeightedRerank # type: ignore from configs import dify_config from core.rag.datasource.vdb.vector_base import BaseVector @@ -17,6 +19,8 @@ from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset +logger = logging.getLogger(__name__) + class TencentConfig(BaseModel): url: str @@ -25,10 +29,11 @@ class TencentConfig(BaseModel): username: Optional[str] database: Optional[str] index_type: str = "HNSW" - metric_type: str = "L2" + metric_type: str = "IP" shard: int = 1 replicas: int = 2 max_upsert_batch_size: int = 128 + enable_hybrid_search: bool = False # Flag to enable hybrid search def to_tencent_params(self): return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout} @@ -44,6 +49,29 @@ class TencentVector(BaseVector): super().__init__(collection_name) self._client_config = config self._client = RPCVectorDBClient(**self._client_config.to_tencent_params()) + self._enable_hybrid_search = False + self._dimension = 1024 + self._load_collection() + self._bm25 = BM25Encoder.default("zh") + + def _load_collection(self): + """ + Check if the collection supports hybrid search. + """ + if self._client_config.enable_hybrid_search: + self._enable_hybrid_search = True + if self._has_collection(): + coll = self._client.describe_collection( + database_name=self._client_config.database, collection_name=self.collection_name + ) + has_hybrid_search = False + for idx in coll.indexes: + if idx.name == "sparse_vector": + has_hybrid_search = True + elif idx.name == "vector": + self._dimension = idx.dimension + if not has_hybrid_search: + self._enable_hybrid_search = False def _init_database(self): return self._client.create_database_if_not_exists(database_name=self._client_config.database) @@ -62,6 +90,7 @@ class TencentVector(BaseVector): ) def _create_collection(self, dimension: int) -> None: + self._dimension = dimension lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) @@ -84,18 +113,25 @@ class TencentVector(BaseVector): if metric_type is None: raise ValueError("unsupported metric_type") params = vdb_index.HNSWParams(m=16, efconstruction=200) - index = vdb_index.Index( - vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY), - vdb_index.VectorIndex( - self.field_vector, - dimension, - index_type, - metric_type, - params, - ), - vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER), - vdb_index.FilterIndex(self.field_metadata, enum.FieldType.Json, enum.IndexType.FILTER), + index_id = vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY) + index_vector = vdb_index.VectorIndex( + self.field_vector, + dimension, + index_type, + metric_type, + params, ) + index_text = vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER) + index_metadate = vdb_index.FilterIndex(self.field_metadata, enum.FieldType.Json, enum.IndexType.FILTER) + index_sparse_vector = vdb_index.SparseIndex( + name="sparse_vector", + field_type=enum.FieldType.SparseVector, + index_type=enum.IndexType.SPARSE_INVERTED, + metric_type=enum.MetricType.IP, + ) + indexes = [index_id, index_vector, index_text, index_metadate] + if self._enable_hybrid_search: + indexes.append(index_sparse_vector) try: self._client.create_collection( database_name=self._client_config.database, @@ -103,31 +139,25 @@ class TencentVector(BaseVector): shard=self._client_config.shard, replicas=self._client_config.replicas, description="Collection for Dify", - index=index, + indexes=indexes, ) except VectorDBException as e: if "fieldType:json" not in e.message: raise e # vdb version not support json, use string - index = vdb_index.Index( - vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY), - vdb_index.VectorIndex( - self.field_vector, - dimension, - index_type, - metric_type, - params, - ), - vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER), - vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER), + index_metadate = vdb_index.FilterIndex( + self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER ) + indexes = [index_id, index_vector, index_text, index_metadate] + if self._enable_hybrid_search: + indexes.append(index_sparse_vector) self._client.create_collection( database_name=self._client_config.database, collection_name=self._collection_name, shard=self._client_config.shard, replicas=self._client_config.replicas, description="Collection for Dify", - index=index, + indexes=indexes, ) redis_client.set(collection_exist_cache_key, 1, ex=3600) @@ -155,6 +185,8 @@ class TencentVector(BaseVector): text=texts[i], metadata=metadata, ) + if self._enable_hybrid_search: + doc.__dict__["sparse_vector"] = self._bm25.encode_texts(texts[i]) docs.append(doc) self._client.upsert( database_name=self._client_config.database, @@ -204,7 +236,32 @@ class TencentVector(BaseVector): return self._get_search_res(res, score_threshold) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - return [] + if not self._enable_hybrid_search: + return [] + res = self._client.hybrid_search( + database_name=self._client_config.database, + collection_name=self.collection_name, + ann=[ + AnnSearch( + field_name="vector", + data=[0.0] * self._dimension, + ) + ], + match=[ + KeywordSearch( + field_name="sparse_vector", + data=self._bm25.encode_queries(query), + ), + ], + rerank=WeightedRerank( + field_list=["vector", "sparse_vector"], + weight=[0, 1], + ), + retrieve_vector=False, + limit=kwargs.get("top_k", 4), + ) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + return self._get_search_res(res, score_threshold) def _get_search_res(self, res: list | None, score_threshold: float) -> list[Document]: docs: list[Document] = [] @@ -213,7 +270,7 @@ class TencentVector(BaseVector): for result in res[0]: meta = result.get(self.field_metadata) - score = 1 - result.get("score", 0.0) + score = result.get("score", 0.0) if score > score_threshold: meta["score"] = score doc = Document(page_content=result.get(self.field_text), metadata=meta) @@ -245,5 +302,6 @@ class TencentVectorFactory(AbstractVectorFactory): database=dify_config.TENCENT_VECTOR_DB_DATABASE, shard=dify_config.TENCENT_VECTOR_DB_SHARD, replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS, + enable_hybrid_search=dify_config.TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH or False, ), ) diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index 4c16f2773..ae5f9761b 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -5,10 +5,11 @@ import pytest from _pytest.monkeypatch import MonkeyPatch from requests.adapters import HTTPAdapter from tcvectordb import RPCVectorDBClient # type: ignore +from tcvectordb.model import enum from tcvectordb.model.collection import FilterIndexConfig -from tcvectordb.model.document import Document, Filter # type: ignore +from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank # type: ignore from tcvectordb.model.enum import ReadConsistency # type: ignore -from tcvectordb.model.index import Index, IndexField # type: ignore +from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex # type: ignore from tcvectordb.rpc.model.collection import RPCCollection from tcvectordb.rpc.model.database import RPCDatabase from xinference_client.types import Embedding # type: ignore @@ -40,6 +41,30 @@ class MockTcvectordbClass: def exists_collection(self, database_name: str, collection_name: str) -> bool: return True + def describe_collection( + self, database_name: str, collection_name: str, timeout: Optional[float] = None + ) -> RPCCollection: + index = Index( + FilterIndex("id", enum.FieldType.String, enum.IndexType.PRIMARY_KEY), + VectorIndex( + "vector", + 128, + enum.IndexType.HNSW, + enum.MetricType.IP, + HNSWParams(m=16, efconstruction=200), + ), + FilterIndex("text", enum.FieldType.String, enum.IndexType.FILTER), + FilterIndex("metadata", enum.FieldType.String, enum.IndexType.FILTER), + ) + return RPCCollection( + RPCDatabase( + name=database_name, + read_consistency=self._read_consistency, + ), + collection_name, + index=index, + ) + def create_collection( self, database_name: str, @@ -97,6 +122,23 @@ class MockTcvectordbClass: ) -> list[list[dict]]: return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]] + def collection_hybrid_search( + self, + database_name: str, + collection_name: str, + ann: Optional[Union[list[AnnSearch], AnnSearch]] = None, + match: Optional[Union[list[KeywordSearch], KeywordSearch]] = None, + filter: Union[Filter, str] = None, + rerank: Optional[Rerank] = None, + retrieve_vector: Optional[bool] = None, + output_fields: Optional[list[str]] = None, + limit: Optional[int] = None, + timeout: Optional[float] = None, + return_pd_object=False, + **kwargs, + ) -> list[list[dict]]: + return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]] + def collection_query( self, database_name: str, @@ -137,8 +179,10 @@ def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch): ) monkeypatch.setattr(RPCVectorDBClient, "exists_collection", MockTcvectordbClass.exists_collection) monkeypatch.setattr(RPCVectorDBClient, "create_collection", MockTcvectordbClass.create_collection) + monkeypatch.setattr(RPCVectorDBClient, "describe_collection", MockTcvectordbClass.describe_collection) monkeypatch.setattr(RPCVectorDBClient, "upsert", MockTcvectordbClass.collection_upsert) monkeypatch.setattr(RPCVectorDBClient, "search", MockTcvectordbClass.collection_search) + monkeypatch.setattr(RPCVectorDBClient, "hybrid_search", MockTcvectordbClass.collection_hybrid_search) monkeypatch.setattr(RPCVectorDBClient, "query", MockTcvectordbClass.collection_query) monkeypatch.setattr(RPCVectorDBClient, "delete", MockTcvectordbClass.collection_delete) monkeypatch.setattr(RPCVectorDBClient, "drop_collection", MockTcvectordbClass.drop_collection) diff --git a/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py b/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py index 1b9466e27..9227bbdcd 100644 --- a/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py +++ b/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py @@ -21,6 +21,7 @@ class TencentVectorTest(AbstractVectorTest): database="dify", shard=1, replicas=2, + enable_hybrid_search=True, ), ) @@ -30,7 +31,7 @@ class TencentVectorTest(AbstractVectorTest): def search_by_full_text(self): hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) - assert len(hits_by_full_text) == 0 + assert len(hits_by_full_text) >= 0 def test_tencent_vector(setup_mock_redis, setup_tcvectordb_mock): diff --git a/docker/.env.example b/docker/.env.example index 0da77f613..4ab55a962 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -515,6 +515,7 @@ TENCENT_VECTOR_DB_USERNAME=dify TENCENT_VECTOR_DB_DATABASE=dify TENCENT_VECTOR_DB_SHARD=1 TENCENT_VECTOR_DB_REPLICAS=2 +TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH=false # ElasticSearch configuration, only available when VECTOR_STORE is `elasticsearch` ELASTICSEARCH_HOST=0.0.0.0 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index edd2b8352..6a3e744cf 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -223,6 +223,7 @@ x-shared-env: &shared-api-worker-env TENCENT_VECTOR_DB_DATABASE: ${TENCENT_VECTOR_DB_DATABASE:-dify} TENCENT_VECTOR_DB_SHARD: ${TENCENT_VECTOR_DB_SHARD:-1} TENCENT_VECTOR_DB_REPLICAS: ${TENCENT_VECTOR_DB_REPLICAS:-2} + TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH: ${TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH:-false} ELASTICSEARCH_HOST: ${ELASTICSEARCH_HOST:-0.0.0.0} ELASTICSEARCH_PORT: ${ELASTICSEARCH_PORT:-9200} ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic}