Add Volcengine VikingDB as new vector provider (#9287)

This commit is contained in:
ice yao
2024-10-13 21:26:05 +08:00
committed by GitHub
parent 1ec83e4969
commit d15ba3939d
15 changed files with 627 additions and 3 deletions

View File

@@ -0,0 +1,215 @@
import os
from typing import Union
from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
from volcengine.viking_db import (
Collection,
Data,
DistanceType,
Field,
FieldType,
Index,
IndexType,
QuantType,
VectorIndexParams,
VikingDBService,
)
from core.rag.datasource.vdb.field import Field as vdb_Field
class MockVikingDBClass:
def __init__(
self,
host="api-vikingdb.volces.com",
region="cn-north-1",
ak="",
sk="",
scheme="http",
connection_timeout=30,
socket_timeout=30,
proxy=None,
):
self._viking_db_service = MagicMock()
self._viking_db_service.get_exception = MagicMock(return_value='{"data": {"primary_key": "test_id"}}')
def get_collection(self, collection_name) -> Collection:
return Collection(
collection_name=collection_name,
description="Collection For Dify",
viking_db_service=self._viking_db_service,
primary_key=vdb_Field.PRIMARY_KEY.value,
fields=[
Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=768),
],
indexes=[
Index(
collection_name=collection_name,
index_name=f"{collection_name}_idx",
vector_index=VectorIndexParams(
distance=DistanceType.L2,
index_type=IndexType.HNSW,
quant=QuantType.Float,
),
scalar_index=None,
stat=None,
viking_db_service=self._viking_db_service,
)
],
)
def drop_collection(self, collection_name):
assert collection_name != ""
def create_collection(self, collection_name, fields, description="") -> Collection:
return Collection(
collection_name=collection_name,
description=description,
primary_key=vdb_Field.PRIMARY_KEY.value,
viking_db_service=self._viking_db_service,
fields=fields,
)
def get_index(self, collection_name, index_name) -> Index:
return Index(
collection_name=collection_name,
index_name=index_name,
viking_db_service=self._viking_db_service,
stat=None,
scalar_index=None,
vector_index=VectorIndexParams(
distance=DistanceType.L2,
index_type=IndexType.HNSW,
quant=QuantType.Float,
),
)
def create_index(
self,
collection_name,
index_name,
vector_index=None,
cpu_quota=2,
description="",
partition_by="",
scalar_index=None,
shard_count=None,
shard_policy=None,
):
return Index(
collection_name=collection_name,
index_name=index_name,
vector_index=vector_index,
cpu_quota=cpu_quota,
description=description,
partition_by=partition_by,
scalar_index=scalar_index,
shard_count=shard_count,
shard_policy=shard_policy,
viking_db_service=self._viking_db_service,
stat=None,
)
def drop_index(self, collection_name, index_name):
assert collection_name != ""
assert index_name != ""
def upsert_data(self, data: Union[Data, list[Data]]):
assert data is not None
def fetch_data(self, id: Union[str, list[str], int, list[int]]):
return Data(
fields={
vdb_Field.GROUP_KEY.value: "test_group",
vdb_Field.METADATA_KEY.value: "{}",
vdb_Field.CONTENT_KEY.value: "content",
vdb_Field.PRIMARY_KEY.value: id,
vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
},
id=id,
)
def delete_data(self, id: Union[str, list[str], int, list[int]]):
assert id is not None
def search_by_vector(
self,
vector,
sparse_vectors=None,
filter=None,
limit=10,
output_fields=None,
partition="default",
dense_weight=None,
) -> list[Data]:
return [
Data(
fields={
vdb_Field.GROUP_KEY.value: "test_group",
vdb_Field.METADATA_KEY.value: '\
{"source": "/var/folders/ml/xxx/xxx.txt", \
"document_id": "test_document_id", \
"dataset_id": "test_dataset_id", \
"doc_id": "test_id", \
"doc_hash": "test_hash"}',
vdb_Field.CONTENT_KEY.value: "content",
vdb_Field.PRIMARY_KEY.value: "test_id",
vdb_Field.VECTOR.value: vector,
},
id="test_id",
score=0.10,
)
]
def search(
self, order=None, filter=None, limit=10, output_fields=None, partition="default", dense_weight=None
) -> list[Data]:
return [
Data(
fields={
vdb_Field.GROUP_KEY.value: "test_group",
vdb_Field.METADATA_KEY.value: '\
{"source": "/var/folders/ml/xxx/xxx.txt", \
"document_id": "test_document_id", \
"dataset_id": "test_dataset_id", \
"doc_id": "test_id", \
"doc_hash": "test_hash"}',
vdb_Field.CONTENT_KEY.value: "content",
vdb_Field.PRIMARY_KEY.value: "test_id",
vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
},
id="test_id",
score=0.10,
)
]
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_vikingdb_mock(monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(VikingDBService, "__init__", MockVikingDBClass.__init__)
monkeypatch.setattr(VikingDBService, "get_collection", MockVikingDBClass.get_collection)
monkeypatch.setattr(VikingDBService, "create_collection", MockVikingDBClass.create_collection)
monkeypatch.setattr(VikingDBService, "drop_collection", MockVikingDBClass.drop_collection)
monkeypatch.setattr(VikingDBService, "get_index", MockVikingDBClass.get_index)
monkeypatch.setattr(VikingDBService, "create_index", MockVikingDBClass.create_index)
monkeypatch.setattr(VikingDBService, "drop_index", MockVikingDBClass.drop_index)
monkeypatch.setattr(Collection, "upsert_data", MockVikingDBClass.upsert_data)
monkeypatch.setattr(Collection, "fetch_data", MockVikingDBClass.fetch_data)
monkeypatch.setattr(Collection, "delete_data", MockVikingDBClass.delete_data)
monkeypatch.setattr(Index, "search_by_vector", MockVikingDBClass.search_by_vector)
monkeypatch.setattr(Index, "search", MockVikingDBClass.search)
yield
if MOCK:
monkeypatch.undo()

View File

@@ -0,0 +1,37 @@
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBConfig, VikingDBVector
from tests.integration_tests.vdb.__mock.vikingdb import setup_vikingdb_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
class VikingDBVectorTest(AbstractVectorTest):
def __init__(self):
super().__init__()
self.vector = VikingDBVector(
"test_collection",
"test_group",
config=VikingDBConfig(
access_key="test_access_key",
host="test_host",
region="test_region",
scheme="test_scheme",
secret_key="test_secret_key",
connection_timeout=30,
socket_timeout=30,
),
)
def search_by_vector(self):
hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
assert len(hits_by_vector) == 1
def search_by_full_text(self):
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="document_id", value="test_document_id")
assert len(ids) > 0
def test_vikingdb_vector(setup_mock_redis, setup_vikingdb_mock):
VikingDBVectorTest().run_all_tests()