Feat/improve vector database logic (#1193)

Co-authored-by: jyong <jyong@dify.ai>
@@ -16,6 +16,10 @@ class BaseIndex(ABC):
     def create(self, texts: list[Document], **kwargs) -> BaseIndex:
         raise NotImplementedError

+    @abstractmethod
+    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+        raise NotImplementedError
+
     @abstractmethod
     def add_texts(self, texts: list[Document], **kwargs):
         raise NotImplementedError
@@ -28,6 +32,10 @@ class BaseIndex(ABC):
     def delete_by_ids(self, ids: list[str]) -> None:
         raise NotImplementedError

+    @abstractmethod
+    def delete_by_group_id(self, group_id: str) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def delete_by_document_id(self, document_id: str):
         raise NotImplementedError
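The two hunks above extend the abstract index contract with group-aware creation and deletion. A minimal, hypothetical in-memory subclass (not part of the commit, documents modeled as plain dicts) illustrating what implementers now have to provide:

from typing import Any


class InMemoryIndex:
    """Toy stand-in for a BaseIndex implementer; each document is a dict
    with a 'metadata' mapping, mirroring langchain's Document shape."""

    def __init__(self):
        self._collections: dict[str, list[dict[str, Any]]] = {}

    def create_with_collection_name(self, texts, collection_name: str, **kwargs):
        # index into an explicitly named (possibly shared) collection
        self._collections.setdefault(collection_name, []).extend(texts)
        return self

    def delete_by_group_id(self, group_id: str) -> None:
        # drop every document whose metadata carries this group id
        for name, docs in self._collections.items():
            self._collections[name] = [
                d for d in docs if d['metadata'].get('group_id') != group_id
            ]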
@@ -46,6 +46,32 @@ class KeywordTableIndex(BaseIndex):

         return self

+    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+        keyword_table_handler = JiebaKeywordTableHandler()
+        keyword_table = {}
+        for text in texts:
+            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+            self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
+            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
+
+        dataset_keyword_table = DatasetKeywordTable(
+            dataset_id=self.dataset.id,
+            keyword_table=json.dumps({
+                '__type__': 'keyword_table',
+                '__data__': {
+                    "index_id": self.dataset.id,
+                    "summary": None,
+                    "table": {}
+                }
+            }, cls=SetEncoder)
+        )
+        db.session.add(dataset_keyword_table)
+        db.session.commit()
+
+        self._save_dataset_keyword_table(keyword_table)
+
+        return self
+
     def add_texts(self, texts: list[Document], **kwargs):
         keyword_table_handler = JiebaKeywordTableHandler()
@@ -120,6 +146,12 @@ class KeywordTableIndex(BaseIndex):
             db.session.delete(dataset_keyword_table)
             db.session.commit()

+    def delete_by_group_id(self, group_id: str) -> None:
+        dataset_keyword_table = self.dataset.dataset_keyword_table
+        if dataset_keyword_table:
+            db.session.delete(dataset_keyword_table)
+            db.session.commit()
+
     def _save_dataset_keyword_table(self, keyword_table):
         keyword_table_dict = {
             '__type__': 'keyword_table',
@@ -10,7 +10,7 @@ from weaviate import UnexpectedStatusCodeException

 from core.index.base import BaseIndex
 from extensions.ext_database import db
-from models.dataset import Dataset, DocumentSegment
+from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding
 from models.dataset import Document as DatasetDocument
@@ -110,6 +110,12 @@ class BaseVectorIndex(BaseIndex):
         for node_id in ids:
             vector_store.del_text(node_id)

+    def delete_by_group_id(self, group_id: str) -> None:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        vector_store.delete()
+
     def delete(self) -> None:
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -243,3 +249,53 @@ class BaseVectorIndex(BaseIndex):
                 raise e

         logging.info(f"Dataset {dataset.id} recreated successfully.")
+
+    def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
+        logging.info(f"restore dataset into one collection, dataset {dataset.id}")
+
+        dataset_documents = db.session.query(DatasetDocument).filter(
+            DatasetDocument.dataset_id == dataset.id,
+            DatasetDocument.indexing_status == 'completed',
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        ).all()
+
+        documents = []
+        for dataset_document in dataset_documents:
+            segments = db.session.query(DocumentSegment).filter(
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.status == 'completed',
+                DocumentSegment.enabled == True
+            ).all()
+
+            for segment in segments:
+                document = Document(
+                    page_content=segment.content,
+                    metadata={
+                        "doc_id": segment.index_node_id,
+                        "doc_hash": segment.index_node_hash,
+                        "document_id": segment.document_id,
+                        "dataset_id": segment.dataset_id,
+                    }
+                )
+
+                documents.append(document)
+
+        if documents:
+            try:
+                self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
+            except Exception as e:
+                raise e
+
+        logging.info(f"Dataset {dataset.id} restored successfully.")
+
+    def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
+        logging.info(f"delete original collection: {dataset.id}")
+
+        self.delete()
+
+        dataset.collection_binding_id = dataset_collection_binding.id
+        db.session.add(dataset)
+        db.session.commit()
+
+        logging.info(f"Dataset {dataset.id} original collection deleted successfully.")
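restore_dataset_in_one and delete_original_collection are the two halves of a migration into a shared collection: re-index the dataset's completed, enabled segments under the binding's collection name, then drop the per-dataset collection and record the binding id on the dataset. A hedged driver sketch; the function name is illustrative, and the index, dataset, and binding arguments are assumed to come from Dify's existing factories and db session:

# Hypothetical migration driver; `vector_index`, `dataset`, and `binding`
# would be obtained from Dify's index builder / db session in real code.
def migrate_dataset_to_shared_collection(vector_index, dataset, binding):
    # 1. copy all completed, enabled segments into the shared collection
    vector_index.restore_dataset_in_one(dataset, binding)
    # 2. drop the old per-dataset collection and point the dataset
    #    at the collection binding
    vector_index.delete_original_collection(dataset, binding)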
@@ -69,6 +69,19 @@ class MilvusVectorIndex(BaseVectorIndex):

         return self

+    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+        uuids = self._get_uuids(texts)
+        self._vector_store = WeaviateVectorStore.from_documents(
+            texts,
+            self._embeddings,
+            client=self._client,
+            index_name=collection_name,
+            uuids=uuids,
+            by_text=False
+        )
+
+        return self
+
     def _get_vector_store(self) -> VectorStore:
         """Only for created index."""
         if self._vector_store:
@@ -28,6 +28,7 @@ from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
 from langchain.vectorstores import VectorStore
 from langchain.vectorstores.utils import maximal_marginal_relevance
+from qdrant_client.http.models import PayloadSchemaType

 if TYPE_CHECKING:
     from qdrant_client import grpc  # noqa
@@ -84,6 +85,7 @@ class Qdrant(VectorStore):

     CONTENT_KEY = "page_content"
     METADATA_KEY = "metadata"
+    GROUP_KEY = "group_id"
     VECTOR_NAME = None

     def __init__(
@@ -93,9 +95,12 @@ class Qdrant(VectorStore):
         embeddings: Optional[Embeddings] = None,
         content_payload_key: str = CONTENT_KEY,
         metadata_payload_key: str = METADATA_KEY,
+        group_payload_key: str = GROUP_KEY,
+        group_id: str = None,
         distance_strategy: str = "COSINE",
         vector_name: Optional[str] = VECTOR_NAME,
         embedding_function: Optional[Callable] = None,  # deprecated
+        is_new_collection: bool = False
     ):
         """Initialize with necessary components."""
         try:
@@ -129,7 +134,10 @@ class Qdrant(VectorStore):
         self.collection_name = collection_name
         self.content_payload_key = content_payload_key or self.CONTENT_KEY
         self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
+        self.group_payload_key = group_payload_key or self.GROUP_KEY
         self.vector_name = vector_name or self.VECTOR_NAME
+        self.group_id = group_id
+        self.is_new_collection = is_new_collection

         if embedding_function is not None:
             warnings.warn(
@@ -170,6 +178,8 @@ class Qdrant(VectorStore):
             batch_size:
                 How many vectors upload per-request.
                 Default: 64
+            group_id:
+                Id of the collection group the texts belong to.

         Returns:
             List of ids from adding the texts into the vectorstore.
@@ -182,7 +192,11 @@ class Qdrant(VectorStore):
                 collection_name=self.collection_name, points=points, **kwargs
             )
             added_ids.extend(batch_ids)
-
+        # if this is a new collection, create a keyword payload index on group_id
+        if self.is_new_collection:
+            self.client.create_payload_index(self.collection_name, self.group_payload_key,
+                                             field_schema=PayloadSchemaType.KEYWORD,
+                                             field_type=PayloadSchemaType.KEYWORD)
        return added_ids

     @sync_call_fallback
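The payload index is created only for newly built collections; it gives the group_id payload field a KEYWORD index so group-scoped filters and deletes avoid a full scan. A standalone sketch against an in-memory Qdrant instance (collection name and vector size are made up for illustration):

from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, PayloadSchemaType, VectorParams

client = QdrantClient(location=":memory:")  # assumption: in-memory instance
client.recreate_collection(
    collection_name="group_index_demo",  # hypothetical name
    vectors_config=VectorParams(size=4, distance=Distance.COSINE),
)
# keyword index over the payload field used as the tenant/group key
client.create_payload_index(
    collection_name="group_index_demo",
    field_name="group_id",
    field_schema=PayloadSchemaType.KEYWORD,
)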
@@ -970,6 +984,8 @@ class Qdrant(VectorStore):
         distance_func: str = "Cosine",
         content_payload_key: str = CONTENT_KEY,
         metadata_payload_key: str = METADATA_KEY,
+        group_payload_key: str = GROUP_KEY,
+        group_id: str = None,
         vector_name: Optional[str] = VECTOR_NAME,
         batch_size: int = 64,
         shard_number: Optional[int] = None,
@@ -1034,6 +1050,11 @@ class Qdrant(VectorStore):
             metadata_payload_key:
                 A payload key used to store the metadata of the document.
                 Default: "metadata"
+            group_payload_key:
+                A payload key used to store the group id of the document.
+                Default: "group_id"
+            group_id:
+                Id of the collection group.
             vector_name:
                 Name of the vector to be used internally in Qdrant.
                 Default: None
@@ -1107,6 +1128,8 @@ class Qdrant(VectorStore):
             distance_func,
             content_payload_key,
             metadata_payload_key,
+            group_payload_key,
+            group_id,
             vector_name,
             shard_number,
             replication_factor,
@@ -1321,6 +1344,8 @@ class Qdrant(VectorStore):
         distance_func: str = "Cosine",
         content_payload_key: str = CONTENT_KEY,
         metadata_payload_key: str = METADATA_KEY,
+        group_payload_key: str = GROUP_KEY,
+        group_id: str = None,
         vector_name: Optional[str] = VECTOR_NAME,
         shard_number: Optional[int] = None,
         replication_factor: Optional[int] = None,
@@ -1350,6 +1375,7 @@ class Qdrant(VectorStore):
         vector_size = len(partial_embeddings[0])
         collection_name = collection_name or uuid.uuid4().hex
         distance_func = distance_func.upper()
+        is_new_collection = False
         client = qdrant_client.QdrantClient(
             location=location,
             url=url,
@@ -1454,6 +1480,7 @@ class Qdrant(VectorStore):
                 init_from=init_from,
                 timeout=timeout,  # type: ignore[arg-type]
             )
+            is_new_collection = True
         qdrant = cls(
             client=client,
             collection_name=collection_name,
@@ -1462,6 +1489,9 @@ class Qdrant(VectorStore):
             metadata_payload_key=metadata_payload_key,
             distance_strategy=distance_func,
             vector_name=vector_name,
+            group_id=group_id,
+            group_payload_key=group_payload_key,
+            is_new_collection=is_new_collection
         )
         return qdrant
@@ -1516,6 +1546,8 @@ class Qdrant(VectorStore):
         metadatas: Optional[List[dict]],
         content_payload_key: str,
         metadata_payload_key: str,
+        group_id: str,
+        group_payload_key: str
     ) -> List[dict]:
         payloads = []
         for i, text in enumerate(texts):
@@ -1529,6 +1561,7 @@ class Qdrant(VectorStore):
                 {
                     content_payload_key: text,
                     metadata_payload_key: metadata,
+                    group_payload_key: group_id
                 }
             )
@@ -1578,7 +1611,7 @@ class Qdrant(VectorStore):
             else:
                 out.append(
                     rest.FieldCondition(
-                        key=f"{self.metadata_payload_key}.{key}",
+                        key=key,
                         match=rest.MatchValue(value=value),
                     )
                 )
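The filter-key change above matches the new payload layout: group_id lives at the top level of the point payload rather than inside metadata, so filter keys are now used as-is instead of being prefixed with the metadata payload key. A sketch of the assumed layout (values are hypothetical):

# Point payload as _build_payloads now writes it (illustrative values):
payload = {
    "page_content": "chunk text…",                    # content_payload_key
    "metadata": {"doc_id": "…", "document_id": "…"},  # metadata_payload_key
    "group_id": "dataset-uuid",                       # group_payload_key
}
# A filter on "group_id" therefore targets the top-level field directly,
# where the old code would have looked up "metadata.group_id".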
@@ -1654,6 +1687,7 @@ class Qdrant(VectorStore):
         metadatas: Optional[List[dict]] = None,
         ids: Optional[Sequence[str]] = None,
         batch_size: int = 64,
+        group_id: Optional[str] = None,
     ) -> Generator[Tuple[List[str], List[rest.PointStruct]], None, None]:
         from qdrant_client.http import models as rest
@@ -1684,6 +1718,8 @@ class Qdrant(VectorStore):
                     batch_metadatas,
                     self.content_payload_key,
                     self.metadata_payload_key,
+                    self.group_id,
+                    self.group_payload_key
                 ),
             )
         ]
@@ -6,18 +6,20 @@ from langchain.embeddings.base import Embeddings
 from langchain.schema import Document, BaseRetriever
 from langchain.vectorstores import VectorStore
 from pydantic import BaseModel
+from qdrant_client.http.models import HnswConfigDiff

 from core.index.base import BaseIndex
 from core.index.vector_index.base import BaseVectorIndex
 from core.vector_store.qdrant_vector_store import QdrantVectorStore
-from models.dataset import Dataset
+from extensions.ext_database import db
+from models.dataset import Dataset, DatasetCollectionBinding


 class QdrantConfig(BaseModel):
     endpoint: str
     api_key: Optional[str]
     root_path: Optional[str]

     def to_qdrant_params(self):
         if self.endpoint and self.endpoint.startswith('path:'):
             path = self.endpoint.replace('path:', '')
@@ -43,16 +45,21 @@ class QdrantVectorIndex(BaseVectorIndex):
         return 'qdrant'

     def get_index_name(self, dataset: Dataset) -> str:
-        if self.dataset.index_struct_dict:
-            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
-            if not class_prefix.endswith('_Node'):
-                # original class_prefix
-                class_prefix += '_Node'
+        if dataset.collection_binding_id:
+            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
+                one_or_none()
+            if dataset_collection_binding:
+                return dataset_collection_binding.collection_name
+            else:
+                raise ValueError('Dataset collection binding does not exist!')
+        else:
+            if self.dataset.index_struct_dict:
+                class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+                return class_prefix

-            return class_prefix
-
-        dataset_id = dataset.id
-        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+            dataset_id = dataset.id
+            return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

     def to_index_struct(self) -> dict:
         return {
@@ -68,6 +75,27 @@ class QdrantVectorIndex(BaseVectorIndex):
             collection_name=self.get_index_name(self.dataset),
             ids=uuids,
             content_payload_key='page_content',
+            group_id=self.dataset.id,
+            group_payload_key='group_id',
+            hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
+                                       max_indexing_threads=0, on_disk=False),
             **self._client_config.to_qdrant_params()
         )

         return self

+    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+        uuids = self._get_uuids(texts)
+        self._vector_store = QdrantVectorStore.from_documents(
+            texts,
+            self._embeddings,
+            collection_name=collection_name,
+            ids=uuids,
+            content_payload_key='page_content',
+            group_id=self.dataset.id,
+            group_payload_key='group_id',
+            hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
+                                       max_indexing_threads=0, on_disk=False),
+            **self._client_config.to_qdrant_params()
+        )
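The HnswConfigDiff used here follows Qdrant's payload-based multitenancy pattern: m=0 skips building the global HNSW graph, while payload_m=16 builds per-group graph links instead, which pairs with the keyword payload index on group_id created by the Qdrant store. A standalone sketch of the same collection setup (collection name and vector size are illustrative):

from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, HnswConfigDiff, VectorParams

client = QdrantClient(location=":memory:")  # assumption: in-memory instance
client.recreate_collection(
    collection_name="multitenant_demo",  # hypothetical name
    vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
    # m=0: no global graph; payload_m=16: per-group links keyed off the
    # indexed group_id payload field
    hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100,
                               full_scan_threshold=10000,
                               max_indexing_threads=0, on_disk=False),
)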
@@ -78,8 +106,6 @@ class QdrantVectorIndex(BaseVectorIndex):
         if self._vector_store:
             return self._vector_store
-        attributes = ['doc_id', 'dataset_id', 'document_id']
-        if self._is_origin():
-            attributes = ['doc_id']
+
         client = qdrant_client.QdrantClient(
             **self._client_config.to_qdrant_params()
         )
@@ -88,16 +114,15 @@ class QdrantVectorIndex(BaseVectorIndex):
             client=client,
             collection_name=self.get_index_name(self.dataset),
             embeddings=self._embeddings,
-            content_payload_key='page_content'
+            content_payload_key='page_content',
+            group_id=self.dataset.id,
+            group_payload_key='group_id'
         )

     def _get_vector_store_class(self) -> type:
         return QdrantVectorStore

     def delete_by_document_id(self, document_id: str):
-        if self._is_origin():
-            self.recreate_dataset(self.dataset)
-            return
-
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -114,9 +139,6 @@ class QdrantVectorIndex(BaseVectorIndex):
         ))

     def delete_by_ids(self, ids: list[str]) -> None:
-        if self._is_origin():
-            self.recreate_dataset(self.dataset)
-            return

         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -132,6 +154,22 @@ class QdrantVectorIndex(BaseVectorIndex):
             ],
         ))

+    def delete_by_group_id(self, group_id: str) -> None:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        from qdrant_client.http import models
+        vector_store.del_texts(models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=group_id),
+                ),
+            ],
+        ))
+
     def _is_origin(self):
         if self.dataset.index_struct_dict:
             class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
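del_texts is Dify's wrapper on the store; for reference, a raw qdrant_client equivalent of the same group-scoped delete might look like this (in-memory instance, collection name, and group id are illustrative):

from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(location=":memory:")  # assumption: in-memory instance
client.recreate_collection(
    collection_name="group_delete_demo",  # hypothetical name
    vectors_config=models.VectorParams(size=4, distance=models.Distance.COSINE),
)
# delete every point whose group_id payload matches the given dataset id
client.delete(
    collection_name="group_delete_demo",
    points_selector=models.FilterSelector(
        filter=models.Filter(must=[
            models.FieldCondition(key="group_id",
                                  match=models.MatchValue(value="dataset-uuid")),
        ])
    ),
)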
@@ -91,6 +91,20 @@ class WeaviateVectorIndex(BaseVectorIndex):

         return self

+    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+        uuids = self._get_uuids(texts)
+        self._vector_store = WeaviateVectorStore.from_documents(
+            texts,
+            self._embeddings,
+            client=self._client,
+            index_name=self.get_index_name(self.dataset),
+            uuids=uuids,
+            by_text=False
+        )
+
+        return self
+
     def _get_vector_store(self) -> VectorStore:
         """Only for created index."""
         if self._vector_store:
@@ -33,7 +33,6 @@ class DatasetRetrieverTool(BaseTool):
     return_resource: str
     retriever_from: str

-
     @classmethod
     def from_dataset(cls, dataset: Dataset, **kwargs):
         description = dataset.description
@@ -94,7 +93,10 @@ class DatasetRetrieverTool(BaseTool):
                 query,
                 search_type='similarity_score_threshold',
                 search_kwargs={
-                    'k': self.k
+                    'k': self.k,
+                    'filter': {
+                        'group_id': [dataset.id]
+                    }
                 }
             )
         else:
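With group_id stored in the point payload, retrieval from a shared collection is narrowed per dataset by the filter in search_kwargs. A sketch of the shape the tool now passes (the dataset id placeholder is hypothetical):

# search_kwargs as built above; the group filter confines similarity
# search to points whose group_id payload matches this dataset's id
search_kwargs = {
    'k': 2,
    'filter': {
        'group_id': ['<dataset-id>']
    }
}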
@@ -46,6 +46,11 @@ class QdrantVectorStore(Qdrant):

         self.client.delete_collection(collection_name=self.collection_name)

+    def delete_group(self):
+        self._reload_if_needed()
+
+        self.client.delete_collection(collection_name=self.collection_name)
+
     @classmethod
     def _document_from_scored_point(
         cls,