Feat/dify rag (#2528)

Co-authored-by: jyong <jyong@dify.ai>
Authored by Jyong on 2024-02-22 23:31:57 +08:00, committed by GitHub
parent 97fe817186
commit 6c4e6bf1d6
119 changed files with 3181 additions and 5892 deletions

@@ -9,20 +9,20 @@ from typing import Optional, cast
from flask import Flask, current_app
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import TextSplitter
from sqlalchemy.orm.exc import ObjectDeletedError
from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatasetDocumentStore
from core.errors.error import ProviderTokenNotInitError
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -31,7 +31,6 @@ from libs import helper
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from models.source import DataSourceBinding
from services.feature_service import FeatureService
@@ -57,38 +56,19 @@ class IndexingRunner:
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# extract
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
# load file
text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
# transform
documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
# save segment
self._load_segments(dataset, dataset_document, documents)
# get embedding model instance
embedding_model_instance = None
if dataset.indexing_technique == 'high_quality':
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
# get splitter
splitter = self._get_splitter(processing_rule, embedding_model_instance)
# split to documents
documents = self._step_split(
text_docs=text_docs,
splitter=splitter,
dataset=dataset,
dataset_document=dataset_document,
processing_rule=processing_rule
)
self._build_index(
# load
self._load(
index_processor=index_processor,
dataset=dataset,
dataset_document=dataset_document,
documents=documents
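
For orientation, the refactored run() path above follows an extract → transform → load-segments → load sequence driven by an index processor. Below is a minimal standalone sketch of that control flow using stub types in place of Dify's real classes (the stubs and sample data are illustrative assumptions; only the step names mirror the diff):

from dataclasses import dataclass, field


@dataclass
class Document:
    page_content: str
    metadata: dict = field(default_factory=dict)


class StubIndexProcessor:
    """Illustrative stand-in for a Dify index processor (paragraph/QA)."""

    def extract(self, source_text: str, process_rule_mode: str) -> list[Document]:
        # In Dify this reads from an upload file or a Notion page; here we fake it.
        return [Document(page_content=source_text)]

    def transform(self, docs: list[Document], process_rule: dict) -> list[Document]:
        # Split/clean according to the processing rule (greatly simplified).
        max_len = process_rule.get("max_tokens", 500)
        return [Document(page_content=d.page_content[:max_len], metadata=d.metadata) for d in docs]

    def load(self, documents: list[Document]) -> None:
        # The real processor persists to the vector store / keyword table here.
        print(f"indexed {len(documents)} chunk(s)")


def run_pipeline(processor: StubIndexProcessor, raw_text: str, process_rule: dict) -> None:
    text_docs = processor.extract(raw_text, process_rule_mode=process_rule["mode"])  # extract
    documents = processor.transform(text_docs, process_rule=process_rule)            # transform
    # (the real runner saves segments to the docstore between transform and load)
    processor.load(documents)                                                        # load


run_pipeline(StubIndexProcessor(), "hello rag", {"mode": "automatic", "max_tokens": 500})
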
@@ -134,39 +114,19 @@ class IndexingRunner:
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
# load file
text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# extract
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
# get embedding model instance
embedding_model_instance = None
if dataset.indexing_technique == 'high_quality':
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
# transform
documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
# save segment
self._load_segments(dataset, dataset_document, documents)
# get splitter
splitter = self._get_splitter(processing_rule, embedding_model_instance)
# split to documents
documents = self._step_split(
text_docs=text_docs,
splitter=splitter,
dataset=dataset,
dataset_document=dataset_document,
processing_rule=processing_rule
)
# build index
self._build_index(
# load
self._load(
index_processor=index_processor,
dataset=dataset,
dataset_document=dataset_document,
documents=documents
@@ -220,7 +180,15 @@ class IndexingRunner:
documents.append(document)
# build index
self._build_index(
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type, processing_rule.to_dict()).init_index_processor()
self._load(
index_processor=index_processor,
dataset=dataset,
dataset_document=dataset_document,
documents=documents
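
The run_in_indexing_status path now builds its processor through IndexProcessorFactory, passing the processing rule along. A toy version of that factory step, with a stub registry standing in for Dify's real paragraph/QA processors (the class names and doc_form keys below are assumptions):

class ParagraphProcessorStub:
    def load(self, dataset, documents):
        print(f"paragraph-indexing {len(documents)} docs into {dataset}")


class QAProcessorStub:
    def load(self, dataset, documents):
        print(f"qa-indexing {len(documents)} docs into {dataset}")


class IndexProcessorFactoryStub:
    # Maps a document form to a processor class; the keys are illustrative.
    _registry = {"text_model": ParagraphProcessorStub, "qa_model": QAProcessorStub}

    def __init__(self, index_type: str, process_rule: dict | None = None):
        self._index_type = index_type
        self._process_rule = process_rule or {}

    def init_index_processor(self):
        return self._registry[self._index_type]()


processor = IndexProcessorFactoryStub("qa_model", {"mode": "automatic"}).init_index_processor()
processor.load("demo-dataset", ["chunk-1", "chunk-2"])
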
@@ -239,16 +207,16 @@ class IndexingRunner:
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
"""
Estimate the indexing for the document.
"""
# check document limit
features = FeatureService.get_features(tenant_id)
if features.billing.enabled:
count = len(file_details)
count = len(extract_settings)
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -284,16 +252,18 @@ class IndexingRunner:
total_segments = 0
total_price = 0
currency = 'USD'
for file_detail in file_details:
index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
all_text_docs = []
for extract_setting in extract_settings:
# extract
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
all_text_docs.extend(text_docs)
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)
# load data from file
text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')
# get splitter
splitter = self._get_splitter(processing_rule, embedding_model_instance)
@@ -305,7 +275,6 @@ class IndexingRunner:
)
total_segments += len(documents)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
@@ -364,154 +333,8 @@ class IndexingRunner:
"preview": preview_texts
}
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
"""
Estimate the indexing for the document.
"""
# check document limit
features = FeatureService.get_features(tenant_id)
if features.billing.enabled:
count = len(notion_info_list)
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
embedding_model_instance = None
if dataset_id:
dataset = Dataset.query.filter_by(
id=dataset_id
).first()
if not dataset:
raise ValueError('Dataset not found.')
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
else:
if indexing_technique == 'high_quality':
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING
)
# load data from notion
tokens = 0
preview_texts = []
total_segments = 0
total_price = 0
currency = 'USD'
for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
raise ValueError('Data source binding not found.')
for page in notion_info['pages']:
loader = NotionLoader(
notion_access_token=data_source_binding.access_token,
notion_workspace_id=workspace_id,
notion_obj_id=page['page_id'],
notion_page_type=page['type']
)
documents = loader.load()
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)
# get splitter
splitter = self._get_splitter(processing_rule, embedding_model_instance)
# split to documents
documents = self._split_to_documents_for_estimate(
text_docs=documents,
splitter=splitter,
processing_rule=processing_rule
)
total_segments += len(documents)
embedding_model_type_instance = None
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
if indexing_technique == 'high_quality' and embedding_model_type_instance:
tokens += embedding_model_type_instance.get_num_tokens(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
texts=[document.page_content]
)
if doc_form and doc_form == 'qa_model':
model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM
)
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language)
document_qa_list = self.format_split_text(response)
price_info = model_type_instance.get_price(
model=model_instance.model,
credentials=model_instance.credentials,
price_type=PriceType.INPUT,
tokens=total_segments * 2000,
)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(price_info.total_amount),
"currency": price_info.currency,
"qa_preview": document_qa_list,
"preview": preview_texts
}
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
embedding_price_info = embedding_model_type_instance.get_price(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
total_price = '{:f}'.format(embedding_price_info.total_amount)
currency = embedding_price_info.currency
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": total_price,
"currency": currency,
"preview": preview_texts
}
def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]:
def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
-> list[Document]:
# load file
if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
return []
@@ -527,11 +350,27 @@ class IndexingRunner:
one_or_none()
if file_detail:
text_docs = FileExtractor.load(file_detail, is_automatic=automatic)
extract_setting = ExtractSetting(
datasource_type="upload_file",
upload_file=file_detail,
document_model=dataset_document.doc_form
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
elif dataset_document.data_source_type == 'notion_import':
loader = NotionLoader.from_document(dataset_document)
text_docs = loader.load()
if (not data_source_info or 'notion_workspace_id' not in data_source_info
or 'notion_page_id' not in data_source_info):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"notion_workspace_id": data_source_info['notion_workspace_id'],
"notion_obj_id": data_source_info['notion_page_id'],
"notion_page_type": data_source_info['notion_page_type'],
"document": dataset_document
},
document_model=dataset_document.doc_form
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
# update document status to splitting
self._update_document_index_status(
document_id=dataset_document.id,
@@ -545,8 +384,6 @@ class IndexingRunner:
# replace doc id to document model id
text_docs = cast(list[Document], text_docs)
for text_doc in text_docs:
# remove invalid symbol
text_doc.page_content = self.filter_string(text_doc.page_content)
text_doc.metadata['document_id'] = dataset_document.id
text_doc.metadata['dataset_id'] = dataset_document.dataset_id
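
The _extract branches above build an ExtractSetting per data source before handing off to index_processor.extract. A condensed sketch of that branching, with plain dicts standing in for the ExtractSetting entity (field names mirror the diff; the fallback page type and error handling are illustrative):

def build_extract_setting(data_source_type: str, data_source_info: dict, doc_form: str) -> dict:
    if data_source_type == "upload_file":
        return {
            "datasource_type": "upload_file",
            "upload_file": data_source_info["upload_file_id"],
            "document_model": doc_form,
        }
    if data_source_type == "notion_import":
        if not {"notion_workspace_id", "notion_page_id"}.issubset(data_source_info):
            raise ValueError("no notion import info found")
        return {
            "datasource_type": "notion_import",
            "notion_info": {
                "notion_workspace_id": data_source_info["notion_workspace_id"],
                "notion_obj_id": data_source_info["notion_page_id"],
                "notion_page_type": data_source_info.get("notion_page_type", "page"),
            },
            "document_model": doc_form,
        }
    raise ValueError(f"unsupported data source type: {data_source_type}")


print(build_extract_setting(
    "notion_import",
    {"notion_workspace_id": "ws", "notion_page_id": "pg", "notion_page_type": "page"},
    "text_model",
))
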
@@ -787,12 +624,12 @@ class IndexingRunner:
for q, a in matches if q and a
]
def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None:
def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset,
dataset_document: DatasetDocument, documents: list[Document]) -> None:
"""
Build the index for the document.
insert index and update document/segment status to completed
"""
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
embedding_model_instance = None
if dataset.indexing_technique == 'high_quality':
embedding_model_instance = self.model_manager.get_model_instance(
@@ -825,13 +662,8 @@ class IndexingRunner:
)
for document in chunk_documents
)
# save vector index
if vector_index:
vector_index.add_texts(chunk_documents)
# save keyword index
keyword_table_index.add_texts(chunk_documents)
# load index
index_processor.load(dataset, chunk_documents)
document_ids = [document.metadata['doc_id'] for document in chunk_documents]
db.session.query(DocumentSegment).filter(
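
For comparison, the old _build_index kept a vector index (high_quality only) and a keyword table side by side; the single index_processor.load call now owns that fan-out. A stubbed sketch of what load() consolidates (the store classes are invented stand-ins; the high_quality/economy split mirrors the removed branches):

class StubVectorStore:
    def add_texts(self, docs):
        print(f"vector index: {len(docs)} chunks")


class StubKeywordTable:
    def add_texts(self, docs):
        print(f"keyword table: {len(docs)} chunks")


class StubParagraphIndexProcessor:
    def __init__(self):
        self.vector = StubVectorStore()
        self.keyword = StubKeywordTable()

    def load(self, indexing_technique: str, documents: list) -> None:
        if indexing_technique == "high_quality":
            self.vector.add_texts(documents)  # embeddings-backed index only for high_quality
        self.keyword.add_texts(documents)     # keyword table is built either way


StubParagraphIndexProcessor().load("high_quality", ["chunk-1", "chunk-2"])
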
@@ -911,14 +743,64 @@ class IndexingRunner:
)
documents.append(document)
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts(documents, duplicate_check=True)
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents)
# save keyword index
index = IndexBuilder.get_index(dataset, 'economy')
if index:
index.add_texts(documents)
def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset,
text_docs: list[Document], process_rule: dict) -> list[Document]:
# get embedding model instance
embedding_model_instance = None
if dataset.indexing_technique == 'high_quality':
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance,
process_rule=process_rule)
return documents
def _load_segments(self, dataset, dataset_document, documents):
# save node to document segment
doc_store = DatasetDocumentStore(
dataset=dataset,
user_id=dataset_document.created_by,
document_id=dataset_document.id
)
# add document segments
doc_store.add_documents(documents)
# update document status to indexing
cur_time = datetime.datetime.utcnow()
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="indexing",
extra_update_params={
DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time,
}
)
# update segment status to indexing
self._update_segments_by_document(
dataset_document_id=dataset_document.id,
update_params={
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.utcnow()
}
)
pass
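
_transform above resolves an embedding model only for the high_quality technique, preferring the dataset's pinned provider/model and otherwise falling back to the tenant default, then delegates splitting and cleaning to index_processor.transform. A trimmed-down sketch of that selection (the resolver and its return values are illustrative):

def resolve_embedding_model(indexing_technique: str, provider: str | None, model: str | None) -> str | None:
    if indexing_technique != "high_quality":
        return None                        # economy indexing needs no embeddings
    if provider:
        return f"{provider}/{model}"       # dataset-pinned embedding model
    return "tenant-default-embedding"      # fall back to the tenant default


print(resolve_embedding_model("high_quality", "openai", "text-embedding-ada-002"))
print(resolve_embedding_model("economy", None, None))
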
class DocumentIsPausedException(Exception):