feat: support tencent vector db (#3568)
This commit is contained in:
0
api/tests/integration_tests/vdb/__mock/__init__.py
Normal file
0
api/tests/integration_tests/vdb/__mock/__init__.py
Normal file
132
api/tests/integration_tests/vdb/__mock/tcvectordb.py
Normal file
132
api/tests/integration_tests/vdb/__mock/tcvectordb.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from requests.adapters import HTTPAdapter
|
||||
from tcvectordb import VectorDBClient
|
||||
from tcvectordb.model.database import Collection, Database
|
||||
from tcvectordb.model.document import Document, Filter
|
||||
from tcvectordb.model.enum import ReadConsistency
|
||||
from tcvectordb.model.index import Index
|
||||
from xinference_client.types import Embedding
|
||||
|
||||
|
||||
class MockTcvectordbClass:
|
||||
|
||||
def VectorDBClient(self, url=None, username='', key='',
|
||||
read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
|
||||
timeout=5,
|
||||
adapter: HTTPAdapter = None):
|
||||
self._conn = None
|
||||
self._read_consistency = read_consistency
|
||||
|
||||
def list_databases(self) -> list[Database]:
|
||||
return [
|
||||
Database(
|
||||
conn=self._conn,
|
||||
read_consistency=self._read_consistency,
|
||||
name='dify',
|
||||
)]
|
||||
|
||||
def list_collections(self, timeout: Optional[float] = None) -> list[Collection]:
|
||||
return []
|
||||
|
||||
def drop_collection(self, name: str, timeout: Optional[float] = None):
|
||||
return {
|
||||
"code": 0,
|
||||
"msg": "operation success"
|
||||
}
|
||||
|
||||
def create_collection(
|
||||
self,
|
||||
name: str,
|
||||
shard: int,
|
||||
replicas: int,
|
||||
description: str,
|
||||
index: Index,
|
||||
embedding: Embedding = None,
|
||||
timeout: float = None,
|
||||
) -> Collection:
|
||||
return Collection(self, name, shard, replicas, description, index, embedding=embedding,
|
||||
read_consistency=self._read_consistency, timeout=timeout)
|
||||
|
||||
def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection:
|
||||
collection = Collection(
|
||||
self,
|
||||
name,
|
||||
shard=1,
|
||||
replicas=2,
|
||||
description=name,
|
||||
timeout=timeout
|
||||
)
|
||||
return collection
|
||||
|
||||
def collection_upsert(
|
||||
self,
|
||||
documents: list[Document],
|
||||
timeout: Optional[float] = None,
|
||||
build_index: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
return {
|
||||
"code": 0,
|
||||
"msg": "operation success"
|
||||
}
|
||||
|
||||
def collection_search(
|
||||
self,
|
||||
vectors: list[list[float]],
|
||||
filter: Filter = None,
|
||||
params=None,
|
||||
retrieve_vector: bool = False,
|
||||
limit: int = 10,
|
||||
output_fields: Optional[list[str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> list[list[dict]]:
|
||||
return [[{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]]
|
||||
|
||||
def collection_query(
|
||||
self,
|
||||
document_ids: Optional[list] = None,
|
||||
retrieve_vector: bool = False,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
filter: Optional[Filter] = None,
|
||||
output_fields: Optional[list[str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> list[dict]:
|
||||
return [{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]
|
||||
|
||||
def collection_delete(
|
||||
self,
|
||||
document_ids: list[str] = None,
|
||||
filter: Filter = None,
|
||||
timeout: float = None,
|
||||
):
|
||||
return {
|
||||
"code": 0,
|
||||
"msg": "operation success"
|
||||
}
|
||||
|
||||
|
||||
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
|
||||
|
||||
@pytest.fixture
|
||||
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(VectorDBClient, '__init__', MockTcvectordbClass.VectorDBClient)
|
||||
monkeypatch.setattr(VectorDBClient, 'list_databases', MockTcvectordbClass.list_databases)
|
||||
monkeypatch.setattr(Database, 'collection', MockTcvectordbClass.describe_collection)
|
||||
monkeypatch.setattr(Database, 'list_collections', MockTcvectordbClass.list_collections)
|
||||
monkeypatch.setattr(Database, 'drop_collection', MockTcvectordbClass.drop_collection)
|
||||
monkeypatch.setattr(Database, 'create_collection', MockTcvectordbClass.create_collection)
|
||||
monkeypatch.setattr(Collection, 'upsert', MockTcvectordbClass.collection_upsert)
|
||||
monkeypatch.setattr(Collection, 'search', MockTcvectordbClass.collection_search)
|
||||
monkeypatch.setattr(Collection, 'query', MockTcvectordbClass.collection_query)
|
||||
monkeypatch.setattr(Collection, 'delete', MockTcvectordbClass.collection_delete)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
35
api/tests/integration_tests/vdb/tcvectordb/test_tencent.py
Normal file
35
api/tests/integration_tests/vdb/tcvectordb/test_tencent.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.rag.datasource.vdb.tencent.tencent_vector import TencentConfig, TencentVector
|
||||
from tests.integration_tests.vdb.__mock.tcvectordb import setup_tcvectordb_mock
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.list_databases.return_value = [{"name": "test"}]
|
||||
|
||||
class TencentVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vector = TencentVector("dify", TencentConfig(
|
||||
url="http://127.0.0.1",
|
||||
api_key="dify",
|
||||
timeout=30,
|
||||
username="dify",
|
||||
database="dify",
|
||||
shard=1,
|
||||
replicas=2,
|
||||
))
|
||||
|
||||
def search_by_vector(self):
|
||||
hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
|
||||
assert len(hits_by_vector) == 1
|
||||
|
||||
def search_by_full_text(self):
|
||||
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
|
||||
assert len(hits_by_full_text) == 0
|
||||
|
||||
def test_tencent_vector(setup_mock_redis,setup_tcvectordb_mock):
|
||||
TencentVectorTest().run_all_tests()
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user