feat: upgrade langchain (#430)

Co-authored-by: jyong <718720800@qq.com>
Author: John Wang
Date: 2023-06-25 16:49:14 +08:00
Committed by: GitHub
Parent: 1dee5de9b4
Commit: 3241e4015b
91 changed files with 2703 additions and 3153 deletions


@@ -1,35 +1,34 @@
import datetime
import json
import logging
import re
import tempfile
import time
from pathlib import Path
from typing import Optional, List
import uuid
from typing import Optional, List, cast
from flask import current_app
from flask_login import current_user
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from llama_index import SimpleDirectoryReader
from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.node_parser import SimpleNodeParser, NodeParser
from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
from llama_index.readers.file.markdown_parser import MarkdownParser
from core.data_source.notion import NotionPageReader
from core.index.readers.xlsx_parser import XLSXParser
from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.index.keyword_table_index import KeywordTableIndex
from core.index.readers.html_parser import HTMLParser
from core.index.readers.markdown_parser import MarkdownParser
from core.index.readers.pdf_parser import PDFParser
from core.index.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.index.vector_index import VectorIndex
from core.embedding.cached_embedding import CacheEmbedding
from core.index.index import IndexBuilder
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.error import ProviderTokenNotInitError
from core.llm.llm_builder import LLMBuilder
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule
from libs import helper
from models.dataset import Document as DatasetDocument
from models.dataset import Dataset, DocumentSegment, DatasetProcessRule
from models.model import UploadFile
from models.source import DataSourceBinding
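
Note on the import changes above: the llama_index stack (SimpleDirectoryReader, NodeParser, MarkdownParser, the DEFAULT_FILE_EXTRACTOR table) is dropped in favour of langchain's Document and TextSplitter plus project-local helpers (FileExtractor, NotionLoader, CacheEmbedding, IndexBuilder). A minimal sketch of the new unit of indexed text, assuming only that langchain is installed; the metadata keys mirror the ones the runner sets later in this diff:

from langchain.schema import Document

# A chunk is now a plain langchain Document; identifiers travel in metadata
# instead of a llama_index Node with doc_id/doc_hash fields and relationships.
chunk = Document(
    page_content="cleaned text for one segment",
    metadata={
        "doc_id": "generated uuid",
        "doc_hash": "hash of page_content",
        "document_id": "dataset document id",
        "dataset_id": "dataset id",
    },
)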
@@ -40,135 +39,171 @@ class IndexingRunner:
self.storage = storage
self.embedding_model_name = embedding_model_name
def run(self, documents: List[Document]):
def run(self, dataset_documents: List[DatasetDocument]):
"""Run the indexing process."""
for document in documents:
for dataset_document in dataset_documents:
try:
# get dataset
dataset = Dataset.query.filter_by(
id=dataset_document.dataset_id
).first()
if not dataset:
raise ValueError("no dataset found")
# load file
text_docs = self._load_data(dataset_document)
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
# get splitter
splitter = self._get_splitter(processing_rule)
# split to documents
documents = self._step_split(
text_docs=text_docs,
splitter=splitter,
dataset=dataset,
dataset_document=dataset_document,
processing_rule=processing_rule
)
# build index
self._build_index(
dataset=dataset,
dataset_document=dataset_document,
documents=documents
)
except DocumentIsPausedException:
raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = 'error'
dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
dataset_document.indexing_status = 'error'
dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
def run_in_splitting_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is splitting."""
try:
# get dataset
dataset = Dataset.query.filter_by(
id=document.dataset_id
id=dataset_document.dataset_id
).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id,
document_id=dataset_document.id
).all()
# delete any existing segments one by one (Session.delete does not accept a list)
for document_segment in document_segments:
    db.session.delete(document_segment)
db.session.commit()
# load file
text_docs = self._load_data(document)
text_docs = self._load_data(dataset_document)
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# get splitter
splitter = self._get_splitter(processing_rule)
# split to nodes
nodes = self._step_split(
# split to documents
documents = self._step_split(
text_docs=text_docs,
node_parser=node_parser,
splitter=splitter,
dataset=dataset,
document=document,
dataset_document=dataset_document,
processing_rule=processing_rule
)
# build index
self._build_index(
dataset=dataset,
document=document,
nodes=nodes
dataset_document=dataset_document,
documents=documents
)
except DocumentIsPausedException:
raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = 'error'
dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
dataset_document.indexing_status = 'error'
dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
def run_in_splitting_status(self, document: Document):
"""Run the indexing process when the index_status is splitting."""
# get dataset
dataset = Dataset.query.filter_by(
id=document.dataset_id
).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id,
document_id=document.id
).all()
db.session.delete(document_segments)
db.session.commit()
# load file
text_docs = self._load_data(document)
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
first()
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# split to nodes
nodes = self._step_split(
text_docs=text_docs,
node_parser=node_parser,
dataset=dataset,
document=document,
processing_rule=processing_rule
)
# build index
self._build_index(
dataset=dataset,
document=document,
nodes=nodes
)
def run_in_indexing_status(self, document: Document):
def run_in_indexing_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is indexing."""
# get dataset
dataset = Dataset.query.filter_by(
id=document.dataset_id
).first()
try:
# get dataset
dataset = Dataset.query.filter_by(
id=dataset_document.dataset_id
).first()
if not dataset:
raise ValueError("no dataset found")
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id,
document_id=document.id
).all()
nodes = []
if document_segments:
for document_segment in document_segments:
# transform segment to node
if document_segment.status != "completed":
relationships = {
DocumentRelationship.SOURCE: document_segment.document_id,
}
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id,
document_id=dataset_document.id
).all()
previous_segment = document_segment.previous_segment
if previous_segment:
relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
documents = []
if document_segments:
for document_segment in document_segments:
# transform segment to document
if document_segment.status != "completed":
document = Document(
page_content=document_segment.content,
metadata={
"doc_id": document_segment.index_node_id,
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
)
next_segment = document_segment.next_segment
if next_segment:
relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
node = Node(
doc_id=document_segment.index_node_id,
doc_hash=document_segment.index_node_hash,
text=document_segment.content,
extra_info=None,
node_info=None,
relationships=relationships
)
nodes.append(node)
documents.append(document)
# build index
self._build_index(
dataset=dataset,
document=document,
nodes=nodes
)
# build index
self._build_index(
dataset=dataset,
dataset_document=dataset_document,
documents=documents
)
except DocumentIsPausedException:
raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = 'error'
dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
dataset_document.indexing_status = 'error'
dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
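
The run_in_indexing_status branch above rebuilds chunks from uncompleted DocumentSegment rows. With langchain, the previous/next relationship bookkeeping that llama_index Nodes required disappears, since ordering already lives in the segment table. An illustrative standalone helper (not part of the diff) showing that conversion in isolation:

from langchain.schema import Document

def segment_to_document(segment) -> Document:
    # hypothetical helper equivalent to the inline conversion above
    return Document(
        page_content=segment.content,
        metadata={
            "doc_id": segment.index_node_id,
            "doc_hash": segment.index_node_hash,
            "document_id": segment.document_id,
            "dataset_id": segment.dataset_id,
        },
    )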
def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
"""
@@ -179,28 +214,28 @@ class IndexingRunner:
total_segments = 0
for file_detail in file_details:
# load data from file
text_docs = self._load_data_from_file(file_detail)
text_docs = FileExtractor.load(file_detail)
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# get splitter
splitter = self._get_splitter(processing_rule)
# split to nodes
nodes = self._split_to_nodes(
# split to documents
documents = self._split_to_documents(
text_docs=text_docs,
node_parser=node_parser,
splitter=splitter,
processing_rule=processing_rule
)
total_segments += len(nodes)
for node in nodes:
total_segments += len(documents)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(node.get_text())
preview_texts.append(document.page_content)
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
return {
"total_segments": total_segments,
@@ -230,35 +265,36 @@ class IndexingRunner:
).first()
if not data_source_binding:
raise ValueError('Data source binding not found.')
reader = NotionPageReader(integration_token=data_source_binding.access_token)
for page in notion_info['pages']:
if page['type'] == 'page':
page_ids = [page['page_id']]
documents = reader.load_data_as_documents(page_ids=page_ids)
elif page['type'] == 'database':
documents = reader.load_data_as_documents(database_id=page['page_id'])
else:
documents = []
loader = NotionLoader(
notion_access_token=data_source_binding.access_token,
notion_workspace_id=workspace_id,
notion_obj_id=page['page_id'],
notion_page_type=page['type']
)
documents = loader.load()
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# get splitter
splitter = self._get_splitter(processing_rule)
# split to nodes
nodes = self._split_to_nodes(
# split to documents
documents = self._split_to_documents(
text_docs=documents,
node_parser=node_parser,
splitter=splitter,
processing_rule=processing_rule
)
total_segments += len(nodes)
for node in nodes:
total_segments += len(documents)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(node.get_text())
preview_texts.append(document.page_content)
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
return {
"total_segments": total_segments,
@@ -268,14 +304,14 @@ class IndexingRunner:
"preview": preview_texts
}
def _load_data(self, document: Document) -> List[Document]:
def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
# load file
if document.data_source_type not in ["upload_file", "notion_import"]:
if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
return []
data_source_info = document.data_source_info_dict
data_source_info = dataset_document.data_source_info_dict
text_docs = []
if document.data_source_type == 'upload_file':
if dataset_document.data_source_type == 'upload_file':
if not data_source_info or 'upload_file_id' not in data_source_info:
raise ValueError("no upload file found")
@@ -283,47 +319,28 @@ class IndexingRunner:
filter(UploadFile.id == data_source_info['upload_file_id']). \
one_or_none()
text_docs = self._load_data_from_file(file_detail)
elif document.data_source_type == 'notion_import':
if not data_source_info or 'notion_page_id' not in data_source_info \
or 'notion_workspace_id' not in data_source_info:
raise ValueError("no notion page found")
workspace_id = data_source_info['notion_workspace_id']
page_id = data_source_info['notion_page_id']
page_type = data_source_info['type']
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == document.tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
raise ValueError('Data source binding not found.')
if page_type == 'page':
# add page last_edited_time to data_source_info
self._get_notion_page_last_edited_time(page_id, data_source_binding.access_token, document)
text_docs = self._load_page_data_from_notion(page_id, data_source_binding.access_token)
elif page_type == 'database':
# add page last_edited_time to data_source_info
self._get_notion_database_last_edited_time(page_id, data_source_binding.access_token, document)
text_docs = self._load_database_data_from_notion(page_id, data_source_binding.access_token)
text_docs = FileExtractor.load(file_detail)
elif dataset_document.data_source_type == 'notion_import':
loader = NotionLoader.from_document(dataset_document)
text_docs = loader.load()
# update document status to splitting
self._update_document_index_status(
document_id=document.id,
document_id=dataset_document.id,
after_indexing_status="splitting",
extra_update_params={
Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]),
Document.parsing_completed_at: datetime.datetime.utcnow()
DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
}
)
# replace doc id to document model id
text_docs = cast(List[Document], text_docs)
for text_doc in text_docs:
# remove invalid symbol
text_doc.text = self.filter_string(text_doc.get_text())
text_doc.doc_id = document.id
text_doc.page_content = self.filter_string(text_doc.page_content)
text_doc.metadata['document_id'] = dataset_document.id
text_doc.metadata['dataset_id'] = dataset_document.dataset_id
return text_docs
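
Two things now happen to every loaded Document before splitting: filter_string strips control characters from page_content, and the dataset/document ids are stamped into metadata so a chunk can always be traced back to its rows. A quick self-contained illustration of the cleaning step, using the same regex shown at the top of the next hunk:

import re

# same pattern as IndexingRunner.filter_string
pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]')
assert pattern.sub('', 'hello\x00 world\x1f') == 'hello world'
# the pattern runs on a decoded str, so \x80-\xFF only removes the
# U+0080..U+00FF code points, not multi-byte text such as CJK characters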
@@ -331,61 +348,7 @@ class IndexingRunner:
pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]')
return pattern.sub('', text)
def _load_data_from_file(self, upload_file: UploadFile) -> List[Document]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
self.storage.download(upload_file.key, filepath)
file_extractor = DEFAULT_FILE_EXTRACTOR.copy()
file_extractor[".markdown"] = MarkdownParser()
file_extractor[".md"] = MarkdownParser()
file_extractor[".html"] = HTMLParser()
file_extractor[".htm"] = HTMLParser()
file_extractor[".pdf"] = PDFParser({'upload_file': upload_file})
file_extractor[".xlsx"] = XLSXParser()
loader = SimpleDirectoryReader(input_files=[filepath], file_extractor=file_extractor)
text_docs = loader.load_data()
return text_docs
def _load_page_data_from_notion(self, page_id: str, access_token: str) -> List[Document]:
page_ids = [page_id]
reader = NotionPageReader(integration_token=access_token)
text_docs = reader.load_data_as_documents(page_ids=page_ids)
return text_docs
def _load_database_data_from_notion(self, database_id: str, access_token: str) -> List[Document]:
reader = NotionPageReader(integration_token=access_token)
text_docs = reader.load_data_as_documents(database_id=database_id)
return text_docs
def _get_notion_page_last_edited_time(self, page_id: str, access_token: str, document: Document):
reader = NotionPageReader(integration_token=access_token)
last_edited_time = reader.get_page_last_edited_time(page_id)
data_source_info = document.data_source_info_dict
data_source_info['last_edited_time'] = last_edited_time
update_params = {
Document.data_source_info: json.dumps(data_source_info)
}
Document.query.filter_by(id=document.id).update(update_params)
db.session.commit()
def _get_notion_database_last_edited_time(self, page_id: str, access_token: str, document: Document):
reader = NotionPageReader(integration_token=access_token)
last_edited_time = reader.get_database_last_edited_time(page_id)
data_source_info = document.data_source_info_dict
data_source_info['last_edited_time'] = last_edited_time
update_params = {
Document.data_source_info: json.dumps(data_source_info)
}
Document.query.filter_by(id=document.id).update(update_params)
db.session.commit()
def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
"""
Get the TextSplitter object according to the processing rule.
"""
@@ -414,68 +377,83 @@ class IndexingRunner:
separators=["\n\n", "。", ".", " ", ""]
)
return SimpleNodeParser(text_splitter=character_splitter, include_extra_info=True)
return character_splitter
def _step_split(self, text_docs: List[Document], node_parser: NodeParser,
dataset: Dataset, document: Document, processing_rule: DatasetProcessRule) -> List[Node]:
def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
-> List[Document]:
"""
Split the text documents into nodes and save them to the document segment.
Split the text documents into chunk documents and save them to the document segments.
"""
nodes = self._split_to_nodes(
documents = self._split_to_documents(
text_docs=text_docs,
node_parser=node_parser,
splitter=splitter,
processing_rule=processing_rule
)
# save chunk documents to document segments
doc_store = DatesetDocumentStore(
dataset=dataset,
user_id=document.created_by,
user_id=dataset_document.created_by,
embedding_model_name=self.embedding_model_name,
document_id=document.id
document_id=dataset_document.id
)
# add document segments
doc_store.add_documents(nodes)
doc_store.add_documents(documents)
# update document status to indexing
cur_time = datetime.datetime.utcnow()
self._update_document_index_status(
document_id=document.id,
document_id=dataset_document.id,
after_indexing_status="indexing",
extra_update_params={
Document.cleaning_completed_at: cur_time,
Document.splitting_completed_at: cur_time,
DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time,
}
)
# update segment status to indexing
self._update_segments_by_document(
document_id=document.id,
dataset_document_id=dataset_document.id,
update_params={
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.utcnow()
}
)
return nodes
return documents
def _split_to_nodes(self, text_docs: List[Document], node_parser: NodeParser,
processing_rule: DatasetProcessRule) -> List[Node]:
def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule) -> List[Document]:
"""
Split the text documents into chunk documents.
"""
all_nodes = []
all_documents = []
for text_doc in text_docs:
# document clean
document_text = self._document_clean(text_doc.get_text(), processing_rule)
text_doc.text = document_text
document_text = self._document_clean(text_doc.page_content, processing_rule)
text_doc.page_content = document_text
# split the document into chunk documents
nodes = node_parser.get_nodes_from_documents([text_doc])
nodes = [node for node in nodes if node.text is not None and node.text.strip()]
all_nodes.extend(nodes)
documents = splitter.split_documents([text_doc])
return all_nodes
split_documents = []
for document in documents:
if document.page_content is None or not document.page_content.strip():
continue
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
document.metadata['doc_id'] = doc_id
document.metadata['doc_hash'] = hash
split_documents.append(document)
all_documents.extend(split_documents)
return all_documents
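
For context, TextSplitter.split_documents in langchain returns new Document objects that inherit the parent Document's metadata, which is why the document_id/dataset_id stamped in _load_data survive into every chunk and only doc_id/doc_hash need to be added here. A minimal usage sketch:

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
chunks = splitter.split_documents([
    Document(page_content="a long cleaned text ..." * 100, metadata={"document_id": "d1"})
])
# every chunk keeps metadata={"document_id": "d1"}; the runner then adds a
# fresh doc_id (uuid4) and a doc_hash per chunk before storing and indexing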
def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
"""
@@ -506,37 +484,38 @@ class IndexingRunner:
return text
def _build_index(self, dataset: Dataset, document: Document, nodes: List[Node]) -> None:
def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
"""
Build the index for the document.
"""
vector_index = VectorIndex(dataset=dataset)
keyword_table_index = KeywordTableIndex(dataset=dataset)
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
# chunk documents by chunk size
indexing_start_at = time.perf_counter()
tokens = 0
chunk_size = 100
for i in range(0, len(nodes), chunk_size):
for i in range(0, len(documents), chunk_size):
# check document is paused
self._check_document_paused_status(document.id)
chunk_nodes = nodes[i:i + chunk_size]
self._check_document_paused_status(dataset_document.id)
chunk_documents = documents[i:i + chunk_size]
tokens += sum(
TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) for node in chunk_nodes
TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
for document in chunk_documents
)
# save vector index
if dataset.indexing_technique == "high_quality":
vector_index.add_nodes(chunk_nodes)
if vector_index:
vector_index.add_texts(chunk_documents)
# save keyword index
keyword_table_index.add_nodes(chunk_nodes)
keyword_table_index.add_texts(chunk_documents)
node_ids = [node.doc_id for node in chunk_nodes]
document_ids = [document.metadata['doc_id'] for document in chunk_documents]
db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == document.id,
DocumentSegment.index_node_id.in_(node_ids),
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == "indexing"
).update({
DocumentSegment.status: "completed",
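
The if vector_index: guard above reflects the assumed contract of IndexBuilder.get_index: it presumably returns None when the dataset's indexing technique does not match the requested tier, while the keyword ('economy') index is always available, and both now expose a langchain-style add_texts(documents) instead of add_nodes(nodes). Sketch of that assumed contract:

# assumed behaviour, inferred from how the hunk above calls it
vector_index = IndexBuilder.get_index(dataset, 'high_quality')    # None unless indexing_technique == "high_quality"
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')  # keyword index is always built

if vector_index:
    vector_index.add_texts(chunk_documents)     # embed and write to the vector store
keyword_table_index.add_texts(chunk_documents)  # update the keyword table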
@@ -549,12 +528,12 @@ class IndexingRunner:
# update document status to completed
self._update_document_index_status(
document_id=document.id,
document_id=dataset_document.id,
after_indexing_status="completed",
extra_update_params={
Document.tokens: tokens,
Document.completed_at: datetime.datetime.utcnow(),
Document.indexing_latency: indexing_end_at - indexing_start_at,
DatasetDocument.tokens: tokens,
DatasetDocument.completed_at: datetime.datetime.utcnow(),
DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
}
)
@@ -569,25 +548,25 @@ class IndexingRunner:
"""
Update the document indexing status.
"""
count = Document.query.filter_by(id=document_id, is_paused=True).count()
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
if count > 0:
raise DocumentIsPausedException()
update_params = {
Document.indexing_status: after_indexing_status
DatasetDocument.indexing_status: after_indexing_status
}
if extra_update_params:
update_params.update(extra_update_params)
Document.query.filter_by(id=document_id).update(update_params)
DatasetDocument.query.filter_by(id=document_id).update(update_params)
db.session.commit()
def _update_segments_by_document(self, document_id: str, update_params: dict) -> None:
def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
"""
Update the document segment by document id.
"""
DocumentSegment.query.filter_by(document_id=document_id).update(update_params)
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit()
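
Taken together, the rewrite in this file swaps the llama_index vocabulary for langchain's throughout the indexing pipeline; the recurring substitutions are roughly:

Node / node.get_text()                 -> Document / document.page_content
NodeParser / SimpleNodeParser          -> TextSplitter via _get_splitter
SimpleDirectoryReader + parser table   -> FileExtractor.load
NotionPageReader                       -> NotionLoader
VectorIndex / KeywordTableIndex direct -> IndexBuilder.get_index(dataset, tier)
add_nodes(nodes)                       -> add_texts(documents)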