Initial commit
api/core/docstore/dataset_docstore.py (new file, 190 lines)
@@ -0,0 +1,190 @@
from typing import Any, Dict, Optional, Sequence

import tiktoken
from llama_index.data_structs import Node
from llama_index.docstore.types import BaseDocumentStore
from llama_index.docstore.utils import json_to_doc
from llama_index.schema import BaseDocument
from sqlalchemy import func

from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment


class DatesetDocumentStore(BaseDocumentStore):
    def __init__(
        self,
        dataset: Dataset,
        user_id: str,
        embedding_model_name: str,
        document_id: Optional[str] = None,
    ):
        self._dataset = dataset
        self._user_id = user_id
        self._embedding_model_name = embedding_model_name
        self._document_id = document_id

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "DatesetDocumentStore":
        return cls(**config_dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to dict."""
        return {
            "dataset_id": self._dataset.id,
        }

    @property
    def dateset_id(self) -> Any:
        return self._dataset.id

    @property
    def user_id(self) -> Any:
        return self._user_id

    @property
    def embedding_model_name(self) -> Any:
        return self._embedding_model_name

    @property
    def docs(self) -> Dict[str, BaseDocument]:
        document_segments = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self._dataset.id
        ).all()

        output = {}
        for document_segment in document_segments:
            doc_id = document_segment.index_node_id
            result = self.segment_to_dict(document_segment)
            output[doc_id] = json_to_doc(result)

        return output

    def add_documents(
        self, docs: Sequence[BaseDocument], allow_update: bool = True
    ) -> None:
        # Continue numbering segments after the highest existing position in this document.
        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
            DocumentSegment.document_id == self._document_id
        ).scalar()

        if max_position is None:
            max_position = 0

        for doc in docs:
            if doc.is_doc_id_none:
                raise ValueError("doc_id not set")

            if not isinstance(doc, Node):
                raise ValueError("doc must be a Node")

            # Look up the existing DB row (not the serialized doc) so updates below are persisted.
            segment_document = self.get_document_segment(doc.get_doc_id())

            # NOTE: doc could already exist in the store, but we overwrite it
            if not allow_update and segment_document:
                raise ValueError(
                    f"doc_id {doc.get_doc_id()} already exists. "
                    "Set allow_update to True to overwrite."
                )

            # Calculate how many tokens embedding this segment will consume.
            tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text())

            if not segment_document:
                max_position += 1

                segment_document = DocumentSegment(
                    tenant_id=self._dataset.tenant_id,
                    dataset_id=self._dataset.id,
                    document_id=self._document_id,
                    index_node_id=doc.get_doc_id(),
                    index_node_hash=doc.get_doc_hash(),
                    position=max_position,
                    content=doc.get_text(),
                    word_count=len(doc.get_text()),
                    tokens=tokens,
                    created_by=self._user_id,
                )
                db.session.add(segment_document)
            else:
                segment_document.content = doc.get_text()
                segment_document.index_node_hash = doc.get_doc_hash()
                segment_document.word_count = len(doc.get_text())
                segment_document.tokens = tokens

        db.session.commit()

    def document_exists(self, doc_id: str) -> bool:
        """Check if document exists."""
        result = self.get_document_segment(doc_id)
        return result is not None

    def get_document(
        self, doc_id: str, raise_error: bool = True
    ) -> Optional[BaseDocument]:
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
            if raise_error:
                raise ValueError(f"doc_id {doc_id} not found.")
            else:
                return None

        result = self.segment_to_dict(document_segment)
        return json_to_doc(result)

    def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
            if raise_error:
                raise ValueError(f"doc_id {doc_id} not found.")
            else:
                return None

        db.session.delete(document_segment)
        db.session.commit()

    def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
        """Set the hash for a given doc_id."""
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
            return None

        document_segment.index_node_hash = doc_hash
        db.session.commit()

    def get_document_hash(self, doc_id: str) -> Optional[str]:
        """Get the stored hash for a document, if it exists."""
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
            return None

        return document_segment.index_node_hash

    def update_docstore(self, other: "BaseDocumentStore") -> None:
        """Update docstore.

        Args:
            other (BaseDocumentStore): docstore to update from

        """
        self.add_documents(list(other.docs.values()))

    def get_document_segment(self, doc_id: str) -> DocumentSegment:
        document_segment = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self._dataset.id,
            DocumentSegment.index_node_id == doc_id
        ).first()

        return document_segment

    def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]:
        return {
            "doc_id": segment.index_node_id,
            "doc_hash": segment.index_node_hash,
            "text": segment.content,
            "__type__": Node.get_type()
        }
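For reference, a minimal usage sketch of DatesetDocumentStore (not part of the commit itself). It assumes a llama_index Node that can be constructed with text, doc_id and doc_hash keyword arguments, and all dataset, document, user and model values below are hypothetical placeholders.

# Usage sketch (illustrative only; ids and model name are placeholders).
from llama_index.data_structs import Node

from core.docstore.dataset_docstore import DatesetDocumentStore
from extensions.ext_database import db
from models.dataset import Dataset

dataset = db.session.query(Dataset).filter(Dataset.id == "dataset-id").first()  # hypothetical id
docstore = DatesetDocumentStore(
    dataset=dataset,
    user_id="user-id",  # hypothetical
    embedding_model_name="text-embedding-ada-002",
    document_id="document-id",  # hypothetical
)

# Assumed Node constructor fields for this llama_index version: text, doc_id, doc_hash.
node = Node(text="Example segment text.", doc_id="node-1", doc_hash="hash-1")
docstore.add_documents([node])

assert docstore.document_exists("node-1")
restored = docstore.get_document("node-1")  # round-trips through segment_to_dict / json_to_doc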
api/core/docstore/empty_docstore.py (new file, 51 lines)
@@ -0,0 +1,51 @@
from typing import Any, Dict, Optional, Sequence

from llama_index.docstore.types import BaseDocumentStore
from llama_index.schema import BaseDocument


class EmptyDocumentStore(BaseDocumentStore):
    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore":
        return cls()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to dict."""
        return {}

    @property
    def docs(self) -> Dict[str, BaseDocument]:
        return {}

    def add_documents(
        self, docs: Sequence[BaseDocument], allow_update: bool = True
    ) -> None:
        pass

    def document_exists(self, doc_id: str) -> bool:
        """Check if document exists."""
        return False

    def get_document(
        self, doc_id: str, raise_error: bool = True
    ) -> Optional[BaseDocument]:
        return None

    def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
        pass

    def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
        """Set the hash for a given doc_id."""
        pass

    def get_document_hash(self, doc_id: str) -> Optional[str]:
        """Get the stored hash for a document, if it exists."""
        return None

    def update_docstore(self, other: "BaseDocumentStore") -> None:
        """Update docstore.

        Args:
            other (BaseDocumentStore): docstore to update from

        """
        self.add_documents(list(other.docs.values()))
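For reference, a short behaviour sketch of EmptyDocumentStore (not part of the commit): every write is a no-op and every read comes back empty, which lets an index be built without keeping node text in llama_index's own docstore; the segments are presumably persisted in the database via DatesetDocumentStore instead.

# Behaviour sketch: EmptyDocumentStore accepts writes but never stores anything.
from core.docstore.empty_docstore import EmptyDocumentStore

store = EmptyDocumentStore()
store.add_documents([])  # no-op
assert store.docs == {}
assert store.document_exists("any-id") is False
assert store.get_document("any-id", raise_error=False) is None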