diff --git a/.gitignore b/.gitignore index 5c68d89a4..30432c430 100644 --- a/.gitignore +++ b/.gitignore @@ -197,6 +197,8 @@ sdks/python-client/dify_client.egg-info !.vscode/README.md pyrightconfig.json api/.vscode +# vscode Code History Extension +.history .idea/ diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 9741dd8b1..fcf3a6d12 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -331,6 +331,12 @@ class QdrantVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from qdrant_client.http import models + score_threshold = float(kwargs.get("score_threshold") or 0.0) + if score_threshold >= 1: + # return empty list because some versions of qdrant may respond with 400 bad request, + # and at the same time, the score_threshold with value 1 may be valid for other vector stores + return [] + filter = models.Filter( must=[ models.FieldCondition( @@ -355,7 +361,7 @@ class QdrantVector(BaseVector): limit=kwargs.get("top_k", 4), with_payload=True, with_vectors=True, - score_threshold=float(kwargs.get("score_threshold") or 0.0), + score_threshold=score_threshold, ) docs = [] for result in results: @@ -363,7 +369,6 @@ class QdrantVector(BaseVector): continue metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold - score_threshold = float(kwargs.get("score_threshold") or 0.0) if result.score > score_threshold: metadata["score"] = result.score doc = Document( diff --git a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py index 61d9a9e71..fe0e03f7b 100644 --- a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py +++ b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py @@ -1,4 +1,5 @@ from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector +from 
core.rag.models.document import Document from tests.integration_tests.vdb.test_vector_store import ( AbstractVectorTest, setup_mock_redis, @@ -18,6 +19,14 @@ class QdrantVectorTest(AbstractVectorTest): ), ) + def search_by_vector(self): + super().search_by_vector() + # only test for qdrant, may not work on other vector stores + hits_by_vector: list[Document] = self.vector.search_by_vector( + query_vector=self.example_embedding, score_threshold=1 + ) + assert len(hits_by_vector) == 0 + def test_qdrant_vector(setup_mock_redis): QdrantVectorTest().run_all_tests()