Fix/vdb index issue (#1776)

Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
Jyong
2023-12-18 21:33:54 +08:00
committed by GitHub
parent 4f0e272549
commit 2fd56cb01c
7 changed files with 18 additions and 10 deletions

View File

@@ -371,7 +371,8 @@ class Completion:
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
-embeddings=embeddings
+embeddings=embeddings,
+attributes=['doc_id', 'annotation_id', 'app_id']
)
documents = vector_index.search(

View File

@@ -100,7 +100,6 @@ class MilvusVectorIndex(BaseVectorIndex):
"""Only for created index."""
if self._vector_store:
return self._vector_store
-attributes = ['doc_id', 'dataset_id', 'document_id']
return MilvusVectorStore(
collection_name=self.get_index_name(self.dataset),

View File

@@ -9,12 +9,17 @@ from models.dataset import Dataset, Document
class VectorIndex:
-def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
+def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings,
+attributes: list = None):
+if attributes is None:
+attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self._dataset = dataset
self._embeddings = embeddings
-self._vector_index = self._init_vector_index(dataset, config, embeddings)
+self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes)
+self._attributes = attributes
-def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
+def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings,
+attributes: list) -> BaseVectorIndex:
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
@@ -33,7 +38,8 @@ class VectorIndex:
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
-embeddings=embeddings
+embeddings=embeddings,
+attributes=attributes
)
elif vector_type == "qdrant":
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig

View File

@@ -27,9 +27,10 @@ class WeaviateConfig(BaseModel):
class WeaviateVectorIndex(BaseVectorIndex):
-def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
+def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
+self._attributes = attributes
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
@@ -111,7 +112,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
if self._vector_store:
return self._vector_store
-attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
+attributes = self._attributes
if self._is_origin():
attributes = ['doc_id']