Feat: Q&A format segmentation support (#668)
Co-authored-by: jyong <718720800@qq.com>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
@@ -1,13 +1,20 @@
import asyncio
import concurrent
import datetime
import json
import logging
import re
import threading
import time
import uuid
from multiprocessing import Process
from typing import Optional, List, cast

from flask import current_app
import openai
from billiard.pool import Pool
from flask import current_app, Flask
from flask_login import current_user
from gevent.threadpool import ThreadPoolExecutor
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
@@ -16,11 +23,13 @@ from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.embedding.cached_embedding import CacheEmbedding
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.error import ProviderTokenNotInitError
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db
@@ -70,7 +79,13 @@ class IndexingRunner:
                dataset_document=dataset_document,
                processing_rule=processing_rule
            )

            # new_documents = []
            # for document in documents:
            #     response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content)
            #     document_qa_list = self.format_split_text(response)
            #     for result in document_qa_list:
            #         document = Document(page_content=result['question'], metadata={'source': result['answer']})
            #         new_documents.append(document)
            # build index
            self._build_index(
                dataset=dataset,
@@ -91,6 +106,22 @@ class IndexingRunner:
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def format_split_text(self, text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
        matches = re.findall(regex, text, re.MULTILINE)

        result = []
        for match in matches:
            q = match[0]
            a = match[1]
            if q and a:
                result.append({
                    "question": q,
                    "answer": re.sub(r"\n\s*", "\n", a.strip())
                })

        return result

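For readers following the change: `format_split_text` parses the raw completion text that the QA generator returns into structured pairs. A minimal sketch of the expected round trip, with illustrative sample text (not taken from the commit):

```python
import re

def format_split_text(text):
    # pair each "Qn:" with its "An:", stopping at the next "Q" or end of input
    regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
    matches = re.findall(regex, text, re.MULTILINE)
    return [
        {"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())}
        for q, a in matches
        if q and a
    ]

sample = "Q1: What does qa_model produce?\nA1: Indexable question-answer pairs.\nQ2: Which model generates them?\nA2: gpt-3.5-turbo."
print(format_split_text(sample))
# [{'question': 'What does qa_model produce?', 'answer': 'Indexable question-answer pairs.'},
#  {'question': 'Which model generates them?', 'answer': 'gpt-3.5-turbo.'}]
```

Note that the `(?=Q|$)` lookahead stops at any capital `Q`, so a generated answer containing one is truncated at that point.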
    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
@@ -205,7 +236,8 @@ class IndexingRunner:
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
    def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict,
                               doc_form: str = None) -> dict:
        """
        Estimate the indexing for the document.
        """
@@ -225,7 +257,7 @@ class IndexingRunner:
        splitter = self._get_splitter(processing_rule)

        # split to documents
        documents = self._split_to_documents(
        documents = self._split_to_documents_for_estimate(
            text_docs=text_docs,
            splitter=splitter,
            processing_rule=processing_rule
@@ -237,7 +269,25 @@ class IndexingRunner:

            tokens += TokenCalculator.get_num_tokens(self.embedding_model_name,
                                                     self.filter_string(document.page_content))

        if doc_form and doc_form == 'qa_model':
            if len(preview_texts) > 0:
                # qa model document
                llm: StreamableOpenAI = LLMBuilder.to_llm(
                    tenant_id=current_user.current_tenant_id,
                    model_name='gpt-3.5-turbo',
                    max_tokens=2000
                )
                response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
                document_qa_list = self.format_split_text(response)
                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(
                        TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
                    "currency": TokenCalculator.get_currency(self.embedding_model_name),
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }
        return {
            "total_segments": total_segments,
            "tokens": tokens,
@@ -246,7 +296,7 @@ class IndexingRunner:
            "preview": preview_texts
        }

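Both estimate branches price the QA path with the same fixed heuristic: each source segment is assumed to expand into about 20 QA pairs and to cost about 2,000 `gpt-3.5-turbo` completion tokens to generate. A back-of-envelope sketch of that arithmetic; the per-1K price constant here is an assumption for illustration, the real value comes from `TokenCalculator.get_token_price`:

```python
# Assumed rate for illustration only; the diff delegates pricing to TokenCalculator.
GPT35_COMPLETION_USD_PER_1K_TOKENS = 0.002

def qa_estimate(total_segments: int) -> dict:
    # heuristic from the diff: ~20 QA pairs and ~2000 completion tokens per segment
    tokens = total_segments * 2000
    return {
        "total_segments": total_segments * 20,
        "tokens": tokens,
        "total_price": '{:f}'.format(tokens / 1000 * GPT35_COMPLETION_USD_PER_1K_TOKENS),
    }

print(qa_estimate(50))
# {'total_segments': 1000, 'tokens': 100000, 'total_price': '0.200000'}
```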
    def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict) -> dict:
    def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
        """
        Estimate the indexing for the document.
        """
@@ -285,7 +335,7 @@ class IndexingRunner:
        splitter = self._get_splitter(processing_rule)

        # split to documents
        documents = self._split_to_documents(
        documents = self._split_to_documents_for_estimate(
            text_docs=documents,
            splitter=splitter,
            processing_rule=processing_rule
@@ -296,7 +346,25 @@ class IndexingRunner:
            preview_texts.append(document.page_content)

            tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)

        if doc_form and doc_form == 'qa_model':
            if len(preview_texts) > 0:
                # qa model document
                llm: StreamableOpenAI = LLMBuilder.to_llm(
                    tenant_id=current_user.current_tenant_id,
                    model_name='gpt-3.5-turbo',
                    max_tokens=2000
                )
                response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
                document_qa_list = self.format_split_text(response)
                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(
                        TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
                    "currency": TokenCalculator.get_currency(self.embedding_model_name),
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }
        return {
            "total_segments": total_segments,
            "tokens": tokens,
@@ -391,7 +459,9 @@ class IndexingRunner:
        documents = self._split_to_documents(
            text_docs=text_docs,
            splitter=splitter,
            processing_rule=processing_rule
            processing_rule=processing_rule,
            tenant_id=dataset.tenant_id,
            document_form=dataset_document.doc_form
        )

        # save node to document segment
@@ -428,7 +498,64 @@ class IndexingRunner:
        return documents

    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
                            processing_rule: DatasetProcessRule) -> List[Document]:
                            processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            llm: StreamableOpenAI = LLMBuilder.to_llm(
                tenant_id=tenant_id,
                model_name='gpt-3.5-turbo',
                max_tokens=2000
            )
            self.format_document(llm, documents, split_documents, document_form)
            all_documents.extend(split_documents)

        return all_documents

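A hedged usage sketch of the widened `_split_to_documents` signature, with the variable names assumed from the surrounding diff: the caller now threads the tenant and the document form through, so QA generation runs against the dataset owner's provider credentials.

```python
# runner, text_docs, splitter, processing_rule, dataset and dataset_document
# are assumed to be in scope as in the surrounding diff
documents = runner._split_to_documents(
    text_docs=text_docs,
    splitter=splitter,
    processing_rule=processing_rule,
    tenant_id=dataset.tenant_id,
    document_form=dataset_document.doc_form  # 'text_model' or 'qa_model'
)
```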
    def format_document(self, llm: StreamableOpenAI, documents: List[Document], split_documents: List, document_form: str):
        for document_node in documents:
            format_documents = []
            if document_node.page_content is None or not document_node.page_content.strip():
                return format_documents
            if document_form == 'text_model':
                # text model document
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document_node.page_content)

                document_node.metadata['doc_id'] = doc_id
                document_node.metadata['doc_hash'] = hash

                format_documents.append(document_node)
            elif document_form == 'qa_model':
                try:
                    # qa model document
                    response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
                    document_qa_list = self.format_split_text(response)
                    qa_documents = []
                    for result in document_qa_list:
                        qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
                        doc_id = str(uuid.uuid4())
                        hash = helper.generate_text_hash(result['question'])
                        qa_document.metadata['answer'] = result['answer']
                        qa_document.metadata['doc_id'] = doc_id
                        qa_document.metadata['doc_hash'] = hash
                        qa_documents.append(qa_document)
                    format_documents.extend(qa_documents)
                except Exception:
                    continue
            split_documents.extend(format_documents)

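The key design choice in `format_document`: in `qa_model` form the index stores each generated question as `page_content` (what gets embedded and matched at query time) while the answer rides along in `metadata['answer']` (what gets returned). A minimal standalone sketch of that fan-out, assuming the same langchain `Document` type; the app-specific `doc_hash` bookkeeping is omitted here:

```python
import uuid
from langchain.schema import Document

def fan_out_qa(chunk: Document, qa_pairs: list) -> list:
    # one Document per generated question: queries match against the question
    # text, and the paired answer is carried in metadata
    qa_documents = []
    for pair in qa_pairs:
        qa_document = Document(page_content=pair['question'],
                               metadata=chunk.metadata.copy())
        qa_document.metadata['answer'] = pair['answer']
        qa_document.metadata['doc_id'] = str(uuid.uuid4())
        qa_documents.append(qa_document)
    return qa_documents
```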
    def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
                                         processing_rule: DatasetProcessRule) -> List[Document]:
        """
        Split the text documents into nodes.
        """
@@ -445,7 +572,6 @@ class IndexingRunner:
        for document in documents:
            if document.page_content is None or not document.page_content.strip():
                continue

            doc_id = str(uuid.uuid4())
            hash = helper.generate_text_hash(document.page_content)

@@ -487,6 +613,23 @@ class IndexingRunner:

        return text

    def format_split_text(self, text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"  # regex matching the Q and A pairs
        matches = re.findall(regex, text, re.MULTILINE)  # collect all matches

        result = []  # store the final results
        for match in matches:
            q = match[0]
            a = match[1]
            if q and a:
                # only add the pair to the results if both Q and A are present
                result.append({
                    "question": q,
                    "answer": re.sub(r"\n\s*", "\n", a.strip())
                })

        return result

    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
        """
        Build the index for the document.