Feat/support parent child chunk (#12092)

This commit is contained in:
Jyong
2024-12-25 19:49:07 +08:00
committed by GitHub
parent 017d7538ae
commit 9231fdbf4c
54 changed files with 2578 additions and 808 deletions

View File

@@ -8,34 +8,34 @@ import time
import uuid
from typing import Any, Optional, cast
from flask import Flask, current_app
from flask import current_app
from flask_login import current_user # type: ignore
from sqlalchemy.orm.exc import ObjectDeletedError
from configs import dify_config
from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail
from core.errors.error import ProviderTokenNotInitError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.models.document import ChildDocument, Document
from core.rag.splitter.fixed_text_splitter import (
EnhanceRecursiveCharacterTextSplitter,
FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.feature_service import FeatureService
@@ -115,6 +115,9 @@ class IndexingRunner:
for document_segment in document_segments:
db.session.delete(document_segment)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# delete child chunks
db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
processing_rule = (
@@ -183,7 +186,22 @@ class IndexingRunner:
"dataset_id": document_segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = document_segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# build index
@@ -222,7 +240,7 @@ class IndexingRunner:
doc_language: str = "English",
dataset_id: Optional[str] = None,
indexing_technique: str = "economy",
) -> dict:
) -> IndexingEstimate:
"""
Estimate the indexing for the document.
"""
@@ -258,31 +276,38 @@ class IndexingRunner:
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
preview_texts: list[str] = []
preview_texts = []
total_segments = 0
index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
all_text_docs = []
for extract_setting in extract_settings:
# extract
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
all_text_docs.extend(text_docs)
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
)
# get splitter
splitter = self._get_splitter(processing_rule, embedding_model_instance)
# split to documents
documents = self._split_to_documents_for_estimate(
text_docs=text_docs, splitter=splitter, processing_rule=processing_rule
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
documents = index_processor.transform(
text_docs,
embedding_model_instance=embedding_model_instance,
process_rule=processing_rule.to_dict(),
tenant_id=current_user.current_tenant_id,
doc_language=doc_language,
preview=True,
)
total_segments += len(documents)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
if len(preview_texts) < 10:
if doc_form and doc_form == "qa_model":
preview_detail = QAPreviewDetail(
question=document.page_content, answer=document.metadata.get("answer")
)
preview_texts.append(preview_detail)
else:
preview_detail = PreviewDetail(content=document.page_content)
if document.children:
preview_detail.child_chunks = [child.page_content for child in document.children]
preview_texts.append(preview_detail)
# delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
@@ -299,15 +324,8 @@ class IndexingRunner:
db.session.delete(image_file)
if doc_form and doc_form == "qa_model":
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(
current_user.current_tenant_id, preview_texts[0], doc_language
)
document_qa_list = self.format_split_text(response)
return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts}
return {"total_segments": total_segments, "preview": preview_texts}
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
def _extract(
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
@@ -401,31 +419,26 @@ class IndexingRunner:
@staticmethod
def _get_splitter(
processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance]
processing_rule_mode: str,
max_tokens: int,
chunk_overlap: int,
separator: str,
embedding_model_instance: Optional[ModelInstance],
) -> TextSplitter:
"""
Get the NodeParser object according to the processing rule.
"""
character_splitter: TextSplitter
if processing_rule.mode == "custom":
if processing_rule_mode in ["custom", "hierarchical"]:
# The user-defined segmentation rule
rules = json.loads(processing_rule.rules)
segmentation = rules["segmentation"]
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
separator = segmentation["separator"]
if separator:
separator = separator.replace("\\n", "\n")
if segmentation.get("chunk_overlap"):
chunk_overlap = segmentation["chunk_overlap"]
else:
chunk_overlap = 0
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=segmentation["max_tokens"],
chunk_size=max_tokens,
chunk_overlap=chunk_overlap,
fixed_separator=separator,
separators=["\n\n", "", ". ", " ", ""],
@@ -443,142 +456,6 @@ class IndexingRunner:
return character_splitter
def _step_split(
    self,
    text_docs: list[Document],
    splitter: TextSplitter,
    dataset: Dataset,
    dataset_document: DatasetDocument,
    processing_rule: DatasetProcessRule,
) -> list[Document]:
    """Split *text_docs* into segment documents, persist them, and mark indexing.

    The split documents are stored through DatasetDocumentStore, after which
    the dataset document and all of its segments are moved to the "indexing"
    status. Returns the list of split documents.
    """
    split_docs = self._split_to_documents(
        text_docs=text_docs,
        splitter=splitter,
        processing_rule=processing_rule,
        tenant_id=dataset.tenant_id,
        document_form=dataset_document.doc_form,
        document_language=dataset_document.doc_language,
    )

    # Persist the split documents as document segments.
    store = DatasetDocumentStore(
        dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
    )
    store.add_documents(split_docs)

    # Move the parent document to "indexing" and stamp the completion times
    # of the cleaning and splitting phases (naive UTC, matching file style).
    now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
    self._update_document_index_status(
        document_id=dataset_document.id,
        after_indexing_status="indexing",
        extra_update_params={
            DatasetDocument.cleaning_completed_at: now,
            DatasetDocument.splitting_completed_at: now,
        },
    )

    # Move every segment of this document to "indexing" as well.
    self._update_segments_by_document(
        dataset_document_id=dataset_document.id,
        update_params={
            DocumentSegment.status: "indexing",
            DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
        },
    )

    return split_docs
def _split_to_documents(
    self,
    text_docs: list[Document],
    splitter: TextSplitter,
    processing_rule: DatasetProcessRule,
    tenant_id: str,
    document_form: str,
    document_language: str,
) -> list[Document]:
    """
    Split the text documents into nodes.

    Each text document is cleaned according to *processing_rule* and split
    with *splitter*; every non-empty chunk receives a fresh ``doc_id`` and
    ``doc_hash`` in its metadata, and leading splitter symbols are stripped
    from its content.  When *document_form* is ``"qa_model"`` the chunks are
    additionally converted into Q&A documents via LLM generation — ten
    chunks per thread batch — and the generated Q&A documents are returned
    instead of the raw chunks.
    """
    all_documents: list[Document] = []
    all_qa_documents: list[Document] = []
    for text_doc in text_docs:
        # document clean
        document_text = self._document_clean(text_doc.page_content, processing_rule)
        text_doc.page_content = document_text

        # parse document to nodes
        documents = splitter.split_documents([text_doc])
        split_documents = []
        for document_node in documents:
            if document_node.page_content.strip():
                if document_node.metadata is not None:
                    doc_id = str(uuid.uuid4())
                    # renamed from `hash` to avoid shadowing the builtin
                    doc_hash = helper.generate_text_hash(document_node.page_content)
                    document_node.metadata["doc_id"] = doc_id
                    document_node.metadata["doc_hash"] = doc_hash
                # delete Splitter character
                page_content = document_node.page_content
                document_node.page_content = remove_leading_symbols(page_content)

                # stripping the leading symbols may leave the chunk empty
                if document_node.page_content:
                    split_documents.append(document_node)
        all_documents.extend(split_documents)

    # processing qa document: fan out LLM generation in batches of 10 threads
    if document_form == "qa_model":
        for i in range(0, len(all_documents), 10):
            threads = []
            sub_documents = all_documents[i : i + 10]
            for doc in sub_documents:
                document_format_thread = threading.Thread(
                    target=self.format_qa_document,
                    kwargs={
                        "flask_app": current_app._get_current_object(),  # type: ignore
                        "tenant_id": tenant_id,
                        "document_node": doc,
                        "all_qa_documents": all_qa_documents,
                        "document_language": document_language,
                    },
                )
                threads.append(document_format_thread)
                document_format_thread.start()
            for thread in threads:
                thread.join()
        return all_qa_documents

    return all_documents
def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
    """
    Generate Q&A documents for one chunk and append them to *all_qa_documents*.

    Intended as a thread target: pushes a Flask app context before touching
    app-bound resources.  One Document per generated question is produced,
    carrying the answer plus a fresh doc_id/doc_hash in copied metadata.
    Failures are logged and swallowed so one bad chunk does not abort the
    whole batch (best-effort by design); blank chunks are skipped outright.
    """
    format_documents = []
    if document_node.page_content is None or not document_node.page_content.strip():
        return
    with flask_app.app_context():
        try:
            # qa model document
            response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
            document_qa_list = self.format_split_text(response)
            qa_documents = []
            for result in document_qa_list:
                qa_document = Document(
                    page_content=result["question"], metadata=document_node.metadata.model_copy()
                )
                if qa_document.metadata is not None:
                    doc_id = str(uuid.uuid4())
                    # avoid shadowing the builtin `hash`
                    text_hash = helper.generate_text_hash(result["question"])
                    qa_document.metadata["answer"] = result["answer"]
                    qa_document.metadata["doc_id"] = doc_id
                    qa_document.metadata["doc_hash"] = text_hash
                qa_documents.append(qa_document)
            format_documents.extend(qa_documents)
        except Exception:
            # deliberate best-effort: the unused `as e` binding was dropped;
            # logging.exception already records the active traceback
            logging.exception("Failed to format qa document")

        all_qa_documents.extend(format_documents)
def _split_to_documents_for_estimate(
self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
) -> list[Document]:
@@ -624,11 +501,11 @@ class IndexingRunner:
return document_text
@staticmethod
def format_split_text(text):
def format_split_text(text: str) -> list[QAPreviewDetail]:
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
matches = re.findall(regex, text, re.UNICODE)
return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a]
return [QAPreviewDetail(question=q, answer=re.sub(r"\n\s*", "\n", a.strip())) for q, a in matches if q and a]
def _load(
self,
@@ -654,13 +531,14 @@ class IndexingRunner:
indexing_start_at = time.perf_counter()
tokens = 0
chunk_size = 10
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
)
create_keyword_thread.start()
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore
)
create_keyword_thread.start()
if dataset.indexing_technique == "high_quality":
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = []
@@ -680,8 +558,8 @@ class IndexingRunner:
for future in futures:
tokens += future.result()
create_keyword_thread.join()
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
create_keyword_thread.join()
indexing_end_at = time.perf_counter()
# update document status to completed
@@ -793,28 +671,6 @@ class IndexingRunner:
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit()
@staticmethod
def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset):
    """
    Batch add segments index processing.

    Wraps every segment as a Document (content plus doc_id/doc_hash and
    document/dataset ids in metadata) and loads the batch into the vector
    index through the processor matching ``dataset.doc_form``.
    """
    docs = [
        Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            },
        )
        for segment in segments
    ]

    # save vector index
    processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
    processor.load(dataset, docs)
def _transform(
self,
index_processor: BaseIndexProcessor,
@@ -856,7 +712,7 @@ class IndexingRunner:
)
# add document segments
doc_store.add_documents(documents)
doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX)
# update document status to indexing
cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)