chore(api/core): apply ruff reformatting (#7624)

Bowen Liang authored on 2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21,180 additions and 21,123 deletions
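
For context: a sweep like this comes from running ruff's formatter over the tree and committing the result wholesale. A minimal sketch of that step, assuming a plain invocation (the actual target path and options used for this commit are not shown in this excerpt):

# Sketch only: apply ruff's formatter to the api/ tree, as this commit does
# across 724 files. The "api" target path is an assumption.
import subprocess

subprocess.run(["ruff", "format", "api"], check=True)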


@@ -12,10 +12,10 @@ from models.dataset import Dataset, DocumentSegment
 class DatasetDocumentStore:
     def __init__(
-            self,
-            dataset: Dataset,
-            user_id: str,
-            document_id: Optional[str] = None,
+        self,
+        dataset: Dataset,
+        user_id: str,
+        document_id: Optional[str] = None,
     ):
         self._dataset = dataset
         self._user_id = user_id
@@ -41,9 +41,9 @@ class DatasetDocumentStore:
     @property
     def docs(self) -> dict[str, Document]:
-        document_segments = db.session.query(DocumentSegment).filter(
-            DocumentSegment.dataset_id == self._dataset.id
-        ).all()
+        document_segments = (
+            db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == self._dataset.id).all()
+        )

         output = {}
         for document_segment in document_segments:
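
The hunk above is the formatter's standard treatment of an over-long chained call: wrap the whole expression in parentheses, then keep the chain on one line if it now fits, or break it one call per line if not. A runnable illustration of the broken form, using only string methods (the names are illustrative, not from the diff):

# When a chain exceeds the line limit, each call goes on its own line
# inside the wrapping parentheses:
text = "  Hello, World  "
cleaned = (
    text.strip()
    .lower()
    .replace(",", "")
)
print(cleaned)  # -> "hello world"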
@@ -55,48 +55,45 @@ class DatasetDocumentStore:
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
},
)
return output
def add_documents(
self, docs: Sequence[Document], allow_update: bool = True
) -> None:
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == self._document_id
).scalar()
def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None:
max_position = (
db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == self._document_id)
.scalar()
)
if max_position is None:
max_position = 0
embedding_model = None
if self._dataset.indexing_technique == 'high_quality':
if self._dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=self._dataset.embedding_model
model=self._dataset.embedding_model,
)
for doc in docs:
if not isinstance(doc, Document):
raise ValueError("doc must be a Document")
segment_document = self.get_document_segment(doc_id=doc.metadata['doc_id'])
segment_document = self.get_document_segment(doc_id=doc.metadata["doc_id"])
# NOTE: doc could already exist in the store, but we overwrite it
if not allow_update and segment_document:
raise ValueError(
f"doc_id {doc.metadata['doc_id']} already exists. "
"Set allow_update to True to overwrite."
f"doc_id {doc.metadata['doc_id']} already exists. " "Set allow_update to True to overwrite."
)
# calc embedding use tokens
if embedding_model:
tokens = embedding_model.get_text_embedding_num_tokens(
texts=[doc.page_content]
)
tokens = embedding_model.get_text_embedding_num_tokens(texts=[doc.page_content])
else:
tokens = 0
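
One artifact worth noting in the hunk above: the two-line error message becomes f"..." "..." on a single line. The formatter joined the physical lines but did not merge the adjacent string literals; Python's implicit concatenation makes the result identical at runtime, so the odd-looking juxtaposition is harmless. A runnable check (literal values are illustrative):

# Adjacent string literals are concatenated at parse time, so the one-line
# reformatted version yields exactly the same message as the original.
msg = "doc_id 123 already exists. " "Set allow_update to True to overwrite."
assert msg == "doc_id 123 already exists. Set allow_update to True to overwrite."
print(msg)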
@@ -107,8 +104,8 @@ class DatasetDocumentStore:
                     tenant_id=self._dataset.tenant_id,
                     dataset_id=self._dataset.id,
                     document_id=self._document_id,
-                    index_node_id=doc.metadata['doc_id'],
-                    index_node_hash=doc.metadata['doc_hash'],
+                    index_node_id=doc.metadata["doc_id"],
+                    index_node_hash=doc.metadata["doc_hash"],
                     position=max_position,
                     content=doc.page_content,
                     word_count=len(doc.page_content),
@@ -116,15 +113,15 @@ class DatasetDocumentStore:
                     enabled=False,
                     created_by=self._user_id,
                 )
-                if doc.metadata.get('answer'):
-                    segment_document.answer = doc.metadata.pop('answer', '')
+                if doc.metadata.get("answer"):
+                    segment_document.answer = doc.metadata.pop("answer", "")

                 db.session.add(segment_document)
             else:
                 segment_document.content = doc.page_content
-                if doc.metadata.get('answer'):
-                    segment_document.answer = doc.metadata.pop('answer', '')
-                segment_document.index_node_hash = doc.metadata['doc_hash']
+                if doc.metadata.get("answer"):
+                    segment_document.answer = doc.metadata.pop("answer", "")
+                segment_document.index_node_hash = doc.metadata["doc_hash"]
                 segment_document.word_count = len(doc.page_content)
                 segment_document.tokens = tokens
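
The changes above are pure quote normalization: ruff's formatter prefers double quotes by default, and the two styles denote the same string values, so behavior is unchanged. A small runnable demonstration, including the get-then-pop pattern the code relies on (dict contents are illustrative):

# Quote style is cosmetic; both literals are equal.
assert 'answer' == "answer"

metadata = {"answer": "42"}
# As in the diff: .get() checks for a truthy value, then .pop() removes the
# key and returns it (or the "" default if it is already gone).
if metadata.get("answer"):
    answer = metadata.pop("answer", "")
print(answer, metadata)  # -> 42 {}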
@@ -135,9 +132,7 @@ class DatasetDocumentStore:
         result = self.get_document_segment(doc_id)
         return result is not None

-    def get_document(
-            self, doc_id: str, raise_error: bool = True
-    ) -> Optional[Document]:
+    def get_document(self, doc_id: str, raise_error: bool = True) -> Optional[Document]:
         document_segment = self.get_document_segment(doc_id)

         if document_segment is None:
@@ -153,7 +148,7 @@ class DatasetDocumentStore:
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
},
)
def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
@@ -188,9 +183,10 @@ class DatasetDocumentStore:
         return document_segment.index_node_hash

     def get_document_segment(self, doc_id: str) -> DocumentSegment:
-        document_segment = db.session.query(DocumentSegment).filter(
-            DocumentSegment.dataset_id == self._dataset.id,
-            DocumentSegment.index_node_id == doc_id
-        ).first()
+        document_segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
+            .first()
+        )

         return document_segment
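
The final hunk folds a two-line filter onto one line. SQLAlchemy combines multiple criteria passed to a single .filter() call with AND either way, so the generated SQL is unchanged. A self-contained sketch against an in-memory SQLite database (the Segment model is a hypothetical stand-in for DocumentSegment):

# Demonstrates that .filter(a, b) is an implicit AND of both criteria.
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Segment(Base):
    __tablename__ = "segments"
    id = Column(Integer, primary_key=True)
    dataset_id = Column(String)
    index_node_id = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Segment(dataset_id="d1", index_node_id="n1"))
    session.commit()
    segment = (
        session.query(Segment)
        .filter(Segment.dataset_id == "d1", Segment.index_node_id == "n1")
        .first()
    )
    print(segment.id)  # -> 1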