feat: support tencent vector db (#3568)

This commit is contained in:
quicksand
2024-06-14 19:25:17 +08:00
committed by GitHub
parent 9ed21737d5
commit 4080f7b8ad
16 changed files with 481 additions and 5 deletions

View File

@@ -0,0 +1,132 @@
import os
from typing import Optional
import pytest
from _pytest.monkeypatch import MonkeyPatch
from requests.adapters import HTTPAdapter
from tcvectordb import VectorDBClient
from tcvectordb.model.database import Collection, Database
from tcvectordb.model.document import Document, Filter
from tcvectordb.model.enum import ReadConsistency
from tcvectordb.model.index import Index
from xinference_client.types import Embedding
class MockTcvectordbClass:
    """Canned replacements for the tcvectordb client, database and collection APIs.

    The functions defined here are not meant to be called on an instance of
    this class. The ``setup_tcvectordb_mock`` fixture patches them onto
    ``VectorDBClient``, ``Database`` and ``Collection``, so ``self`` is an
    instance of whichever class a given function was patched onto.
    """

    def VectorDBClient(self, url=None, username='', key='',
                       read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
                       timeout=5,
                       adapter: Optional[HTTPAdapter] = None):
        """Replacement for ``VectorDBClient.__init__``.

        Only records the read consistency and a ``None`` connection; ``url``,
        ``username``, ``key``, ``timeout`` and ``adapter`` are accepted for
        signature compatibility and ignored.
        """
        self._conn = None
        self._read_consistency = read_consistency

    def list_databases(self) -> list[Database]:
        """Return a single database named ``dify`` built from the stored state."""
        return [
            Database(
                conn=self._conn,
                read_consistency=self._read_consistency,
                name='dify',
            )
        ]

    def list_collections(self, timeout: Optional[float] = None) -> list[Collection]:
        """Report that the database holds no collections."""
        return []

    def drop_collection(self, name: str, timeout: Optional[float] = None):
        """Pretend to drop ``name`` and return a success payload."""
        return {
            "code": 0,
            "msg": "operation success"
        }

    def create_collection(
            self,
            name: str,
            shard: int,
            replicas: int,
            description: str,
            index: Index,
            embedding: Optional[Embedding] = None,
            timeout: Optional[float] = None,
    ) -> Collection:
        """Build a ``Collection`` object from the arguments without any remote call.

        NOTE(review): ``Embedding`` is imported from ``xinference_client.types``
        in this file; the annotation presumably should refer to tcvectordb's
        own embedding type -- confirm.
        """
        return Collection(self, name, shard, replicas, description, index, embedding=embedding,
                          read_consistency=self._read_consistency, timeout=timeout)

    def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection:
        """Return a ``Collection`` for ``name`` with fixed shard/replica counts."""
        collection = Collection(
            self,
            name,
            shard=1,
            replicas=2,
            description=name,
            timeout=timeout
        )
        return collection

    def collection_upsert(
            self,
            documents: list[Document],
            timeout: Optional[float] = None,
            build_index: bool = True,
            **kwargs
    ):
        """Accept any documents and return a success payload."""
        return {
            "code": 0,
            "msg": "operation success"
        }

    def collection_search(
            self,
            vectors: list[list[float]],
            filter: Optional[Filter] = None,
            params=None,
            retrieve_vector: bool = False,
            limit: int = 10,
            output_fields: Optional[list[str]] = None,
            timeout: Optional[float] = None,
    ) -> list[list[dict]]:
        """Return one fixed hit (``doc_id`` ``foo1``) in a single result list.

        The result does not depend on any of the arguments.
        """
        return [[{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]]

    def collection_query(
            self,
            document_ids: Optional[list] = None,
            retrieve_vector: bool = False,
            limit: Optional[int] = None,
            offset: Optional[int] = None,
            filter: Optional[Filter] = None,
            output_fields: Optional[list[str]] = None,
            timeout: Optional[float] = None,
    ) -> list[dict]:
        """Return one fixed document (``doc_id`` ``foo1``) regardless of the arguments."""
        return [{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]

    def collection_delete(
            self,
            document_ids: Optional[list[str]] = None,
            filter: Optional[Filter] = None,
            timeout: Optional[float] = None,
    ):
        """Pretend to delete the selected documents and return a success payload."""
        return {
            "code": 0,
            "msg": "operation success"
        }
# Mocking is opt-in: enabled only when MOCK_SWITCH is "true" (case-insensitive).
MOCK = os.environ.get('MOCK_SWITCH', 'false').lower() == 'true'
@pytest.fixture
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
    """Replace the tcvectordb client, database and collection methods with mocks.

    Patching only happens when ``MOCK`` is true; otherwise the real
    tcvectordb classes are left untouched and the fixture just yields.
    """
    if MOCK:
        # (target class, attribute to replace, mock implementation)
        patches = (
            (VectorDBClient, '__init__', MockTcvectordbClass.VectorDBClient),
            (VectorDBClient, 'list_databases', MockTcvectordbClass.list_databases),
            (Database, 'collection', MockTcvectordbClass.describe_collection),
            (Database, 'list_collections', MockTcvectordbClass.list_collections),
            (Database, 'drop_collection', MockTcvectordbClass.drop_collection),
            (Database, 'create_collection', MockTcvectordbClass.create_collection),
            (Collection, 'upsert', MockTcvectordbClass.collection_upsert),
            (Collection, 'search', MockTcvectordbClass.collection_search),
            (Collection, 'query', MockTcvectordbClass.collection_query),
            (Collection, 'delete', MockTcvectordbClass.collection_delete),
        )
        for target, attribute, replacement in patches:
            monkeypatch.setattr(target, attribute, replacement)

    yield

    # No explicit monkeypatch.undo() here: the monkeypatch fixture reverts its
    # own patches at teardown, and calling undo() would also revert patches
    # made through the same monkeypatch by other fixtures of this test.

View File

@@ -0,0 +1,35 @@
from unittest.mock import MagicMock
from core.rag.datasource.vdb.tencent.tencent_vector import TencentConfig, TencentVector
from tests.integration_tests.vdb.__mock.tcvectordb import setup_tcvectordb_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
# Stand-in client whose list_databases() reports a single database named "test".
# NOTE(review): not referenced anywhere else in the visible part of this file.
mock_client = MagicMock(**{"list_databases.return_value": [{"name": "test"}]})
class TencentVectorTest(AbstractVectorTest):
    """Runs the shared vector-store test steps against ``TencentVector``."""

    def __init__(self):
        """Create the ``TencentVector`` under test with a fixed local configuration."""
        super().__init__()
        config = TencentConfig(
            url="http://127.0.0.1",
            api_key="dify",
            timeout=30,
            username="dify",
            database="dify",
            shard=1,
            replicas=2,
        )
        self.vector = TencentVector("dify", config)

    def search_by_vector(self):
        """Search with the example embedding and expect exactly one hit."""
        results = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(results) == 1

    def search_by_full_text(self):
        """Search with the example text and expect no hits.

        NOTE(review): presumably full-text search is not supported by
        TencentVector and always comes back empty -- confirm.
        """
        results = self.vector.search_by_full_text(query=get_example_text())
        assert len(results) == 0
def test_tencent_vector(setup_mock_redis, setup_tcvectordb_mock):
    """Run every step of the Tencent vector test with the redis and tcvectordb fixtures active."""
    suite = TencentVectorTest()
    suite.run_all_tests()