@@ -9,20 +9,20 @@ from typing import Optional, cast

from flask import Flask, current_app
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import TextSplitter
from sqlalchemy.orm.exc import ObjectDeletedError

from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatasetDocumentStore
from core.errors.error import ProviderTokenNotInitError
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -31,7 +31,6 @@ from libs import helper
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from models.source import DataSourceBinding
from services.feature_service import FeatureService
@@ -57,38 +56,19 @@ class IndexingRunner:
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# extract
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

# load file
text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
# transform
documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
# save segment
self._load_segments(dataset, dataset_document, documents)

# get embedding model instance
embedding_model_instance = None
if dataset.indexing_technique == 'high_quality':
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)

# get splitter
splitter = self._get_splitter(processing_rule, embedding_model_instance)

# split to documents
documents = self._step_split(
text_docs=text_docs,
splitter=splitter,
dataset=dataset,
dataset_document=dataset_document,
processing_rule=processing_rule
)
self._build_index(
# load
self._load(
index_processor=index_processor,
dataset=dataset,
dataset_document=dataset_document,
documents=documents
@@ -134,39 +114,19 @@ class IndexingRunner:
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()

# load file
text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# extract
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

# get embedding model instance
embedding_model_instance = None
if dataset.indexing_technique == 'high_quality':
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
# transform
documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
# save segment
self._load_segments(dataset, dataset_document, documents)

# get splitter
splitter = self._get_splitter(processing_rule, embedding_model_instance)

# split to documents
documents = self._step_split(
text_docs=text_docs,
splitter=splitter,
dataset=dataset,
dataset_document=dataset_document,
processing_rule=processing_rule
)

# build index
self._build_index(
# load
self._load(
index_processor=index_processor,
dataset=dataset,
dataset_document=dataset_document,
documents=documents
@@ -220,7 +180,15 @@ class IndexingRunner:
documents.append(document)

# build index
self._build_index(
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()

index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type, processing_rule.to_dict()).init_index_processor()
self._load(
index_processor=index_processor,
dataset=dataset,
dataset_document=dataset_document,
documents=documents
@@ -239,16 +207,16 @@ class IndexingRunner:
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()

def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
"""
Estimate the indexing for the document.
"""
# check document limit
features = FeatureService.get_features(tenant_id)
if features.billing.enabled:
count = len(file_details)
count = len(extract_settings)
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -284,16 +252,18 @@ class IndexingRunner:
total_segments = 0
total_price = 0
currency = 'USD'
for file_detail in file_details:

index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
all_text_docs = []
for extract_setting in extract_settings:
# extract
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
all_text_docs.extend(text_docs)
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)

# load data from file
text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')

# get splitter
splitter = self._get_splitter(processing_rule, embedding_model_instance)
@@ -305,7 +275,6 @@ class IndexingRunner:
)

total_segments += len(documents)

for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
@@ -364,154 +333,8 @@ class IndexingRunner:
"preview": preview_texts
}

def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
"""
Estimate the indexing for the document.
"""
# check document limit
features = FeatureService.get_features(tenant_id)
if features.billing.enabled:
count = len(notion_info_list)
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
embedding_model_instance = None
if dataset_id:
dataset = Dataset.query.filter_by(
id=dataset_id
).first()
if not dataset:
raise ValueError('Dataset not found.')
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
else:
if indexing_technique == 'high_quality':
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING
)
# load data from notion
tokens = 0
preview_texts = []
total_segments = 0
total_price = 0
currency = 'USD'
for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
raise ValueError('Data source binding not found.')

for page in notion_info['pages']:
loader = NotionLoader(
notion_access_token=data_source_binding.access_token,
notion_workspace_id=workspace_id,
notion_obj_id=page['page_id'],
notion_page_type=page['type']
)
documents = loader.load()
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)

# get splitter
splitter = self._get_splitter(processing_rule, embedding_model_instance)

# split to documents
documents = self._split_to_documents_for_estimate(
text_docs=documents,
splitter=splitter,
processing_rule=processing_rule
)
total_segments += len(documents)

embedding_model_type_instance = None
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)

for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
if indexing_technique == 'high_quality' and embedding_model_type_instance:
tokens += embedding_model_type_instance.get_num_tokens(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
texts=[document.page_content]
)
if doc_form and doc_form == 'qa_model':
model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM
)

model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language)
document_qa_list = self.format_split_text(response)

price_info = model_type_instance.get_price(
model=model_instance.model,
credentials=model_instance.credentials,
price_type=PriceType.INPUT,
tokens=total_segments * 2000,
)

return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(price_info.total_amount),
"currency": price_info.currency,
"qa_preview": document_qa_list,
"preview": preview_texts
}
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
embedding_price_info = embedding_model_type_instance.get_price(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
total_price = '{:f}'.format(embedding_price_info.total_amount)
currency = embedding_price_info.currency
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": total_price,
"currency": currency,
"preview": preview_texts
}

def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]:
def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
-> list[Document]:
# load file
if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
return []
@@ -527,11 +350,27 @@ class IndexingRunner:
one_or_none()

if file_detail:
text_docs = FileExtractor.load(file_detail, is_automatic=automatic)
extract_setting = ExtractSetting(
datasource_type="upload_file",
upload_file=file_detail,
document_model=dataset_document.doc_form
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
elif dataset_document.data_source_type == 'notion_import':
loader = NotionLoader.from_document(dataset_document)
text_docs = loader.load()

if (not data_source_info or 'notion_workspace_id' not in data_source_info
or 'notion_page_id' not in data_source_info):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"notion_workspace_id": data_source_info['notion_workspace_id'],
"notion_obj_id": data_source_info['notion_page_id'],
"notion_page_type": data_source_info['notion_page_type'],
"document": dataset_document
},
document_model=dataset_document.doc_form
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
# update document status to splitting
self._update_document_index_status(
document_id=dataset_document.id,
@@ -545,8 +384,6 @@ class IndexingRunner:
# replace doc id to document model id
text_docs = cast(list[Document], text_docs)
for text_doc in text_docs:
# remove invalid symbol
text_doc.page_content = self.filter_string(text_doc.page_content)
text_doc.metadata['document_id'] = dataset_document.id
text_doc.metadata['dataset_id'] = dataset_document.dataset_id
@@ -787,12 +624,12 @@ class IndexingRunner:
for q, a in matches if q and a
]

def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None:
def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset,
dataset_document: DatasetDocument, documents: list[Document]) -> None:
"""
Build the index for the document.
insert index and update document/segment status to completed
"""
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')

embedding_model_instance = None
if dataset.indexing_technique == 'high_quality':
embedding_model_instance = self.model_manager.get_model_instance(
@@ -825,13 +662,8 @@ class IndexingRunner:
)
for document in chunk_documents
)

# save vector index
if vector_index:
vector_index.add_texts(chunk_documents)

# save keyword index
keyword_table_index.add_texts(chunk_documents)
# load index
index_processor.load(dataset, chunk_documents)

document_ids = [document.metadata['doc_id'] for document in chunk_documents]
db.session.query(DocumentSegment).filter(
@@ -911,14 +743,64 @@ class IndexingRunner:
)
documents.append(document)
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts(documents, duplicate_check=True)
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents)

# save keyword index
index = IndexBuilder.get_index(dataset, 'economy')
if index:
index.add_texts(documents)
def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset,
text_docs: list[Document], process_rule: dict) -> list[Document]:
# get embedding model instance
embedding_model_instance = None
if dataset.indexing_technique == 'high_quality':
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance,
process_rule=process_rule)

return documents

def _load_segments(self, dataset, dataset_document, documents):
# save node to document segment
doc_store = DatasetDocumentStore(
dataset=dataset,
user_id=dataset_document.created_by,
document_id=dataset_document.id
)

# add document segments
doc_store.add_documents(documents)

# update document status to indexing
cur_time = datetime.datetime.utcnow()
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="indexing",
extra_update_params={
DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time,
}
)

# update segment status to indexing
self._update_segments_by_document(
dataset_document_id=dataset_document.id,
update_params={
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.utcnow()
}
)
pass


class DocumentIsPausedException(Exception):