Feat/dataset notion import (#392)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: JzoNg <jzongcode@gmail.com>
This commit is contained in:
@@ -3,7 +3,7 @@ import logging
|
||||
import datetime
|
||||
import time
|
||||
import random
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
from extensions.ext_redis import redis_client
|
||||
from flask_login import current_user
|
||||
|
||||
@@ -14,10 +14,12 @@ from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
from models.source import DataSourceBinding
|
||||
from services.errors.account import NoPermissionError
|
||||
from services.errors.dataset import DatasetNameDuplicateError
|
||||
from services.errors.document import DocumentIndexingError
|
||||
from services.errors.file import FileNotExistsError
|
||||
from tasks.clean_notion_document_task import clean_notion_document_task
|
||||
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
|
||||
from tasks.document_indexing_task import document_indexing_task
|
||||
from tasks.document_indexing_update_task import document_indexing_update_task
|
||||
@@ -286,6 +288,24 @@ class DocumentService:
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
def get_document_by_dataset_id(dataset_id: str) -> List[Document]:
|
||||
documents = db.session.query(Document).filter(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.enabled == True
|
||||
).all()
|
||||
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_batch_documents(dataset_id: str, batch: str) -> List[Document]:
|
||||
documents = db.session.query(Document).filter(
|
||||
Document.batch == batch,
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.tenant_id == current_user.current_tenant_id
|
||||
).all()
|
||||
|
||||
return documents
|
||||
@staticmethod
|
||||
def get_document_file_detail(file_id: str):
|
||||
file_detail = db.session.query(UploadFile). \
|
||||
filter(UploadFile.id == file_id). \
|
||||
@@ -344,9 +364,9 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def get_documents_position(dataset_id):
|
||||
documents = Document.query.filter_by(dataset_id=dataset_id).all()
|
||||
if documents:
|
||||
return len(documents) + 1
|
||||
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
|
||||
if document:
|
||||
return document.position + 1
|
||||
else:
|
||||
return 1
|
||||
|
||||
@@ -363,9 +383,11 @@ class DocumentService:
|
||||
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
IndexBuilder.get_default_service_context(dataset.tenant_id)
|
||||
|
||||
documents = []
|
||||
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
|
||||
if 'original_document_id' in document_data and document_data["original_document_id"]:
|
||||
document = DocumentService.update_document_with_dataset_id(dataset, document_data, account)
|
||||
documents.append(document)
|
||||
else:
|
||||
# save process rule
|
||||
if not dataset_process_rule:
|
||||
@@ -386,46 +408,114 @@ class DocumentService:
|
||||
)
|
||||
db.session.add(dataset_process_rule)
|
||||
db.session.commit()
|
||||
|
||||
file_name = ''
|
||||
data_source_info = {}
|
||||
if document_data["data_source"]["type"] == "upload_file":
|
||||
file_id = document_data["data_source"]["info"]
|
||||
file = db.session.query(UploadFile).filter(
|
||||
UploadFile.tenant_id == dataset.tenant_id,
|
||||
UploadFile.id == file_id
|
||||
).first()
|
||||
|
||||
# raise error if file not found
|
||||
if not file:
|
||||
raise FileNotExistsError()
|
||||
|
||||
file_name = file.name
|
||||
data_source_info = {
|
||||
"upload_file_id": file_id,
|
||||
}
|
||||
|
||||
# save document
|
||||
position = DocumentService.get_documents_position(dataset.id)
|
||||
document = Document(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=position,
|
||||
data_source_type=document_data["data_source"]["type"],
|
||||
data_source_info=json.dumps(data_source_info),
|
||||
dataset_process_rule_id=dataset_process_rule.id,
|
||||
batch=time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)),
|
||||
name=file_name,
|
||||
created_from=created_from,
|
||||
created_by=account.id,
|
||||
# created_api_request_id = db.Column(UUID, nullable=True)
|
||||
)
|
||||
document_ids = []
|
||||
if document_data["data_source"]["type"] == "upload_file":
|
||||
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
|
||||
for file_id in upload_file_list:
|
||||
file = db.session.query(UploadFile).filter(
|
||||
UploadFile.tenant_id == dataset.tenant_id,
|
||||
UploadFile.id == file_id
|
||||
).first()
|
||||
|
||||
db.session.add(document)
|
||||
# raise error if file not found
|
||||
if not file:
|
||||
raise FileNotExistsError()
|
||||
|
||||
file_name = file.name
|
||||
data_source_info = {
|
||||
"upload_file_id": file_id,
|
||||
}
|
||||
document = DocumentService.save_document(dataset, dataset_process_rule.id,
|
||||
document_data["data_source"]["type"],
|
||||
data_source_info, created_from, position,
|
||||
account, file_name, batch)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
document_ids.append(document.id)
|
||||
documents.append(document)
|
||||
position += 1
|
||||
elif document_data["data_source"]["type"] == "notion_import":
|
||||
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
|
||||
exist_page_ids = []
|
||||
exist_document = dict()
|
||||
documents = Document.query.filter_by(
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data_source_type='notion_import',
|
||||
enabled=True
|
||||
).all()
|
||||
if documents:
|
||||
for document in documents:
|
||||
data_source_info = json.loads(document.data_source_info)
|
||||
exist_page_ids.append(data_source_info['notion_page_id'])
|
||||
exist_document[data_source_info['notion_page_id']] = document.id
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info['workspace_id']
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
if not data_source_binding:
|
||||
raise ValueError('Data source binding not found.')
|
||||
for page in notion_info['pages']:
|
||||
if page['page_id'] not in exist_page_ids:
|
||||
data_source_info = {
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_page_id": page['page_id'],
|
||||
"notion_page_icon": page['page_icon'],
|
||||
"type": page['type']
|
||||
}
|
||||
document = DocumentService.save_document(dataset, dataset_process_rule.id,
|
||||
document_data["data_source"]["type"],
|
||||
data_source_info, created_from, position,
|
||||
account, page['page_name'], batch)
|
||||
# if page['type'] == 'database':
|
||||
# document.splitting_completed_at = datetime.datetime.utcnow()
|
||||
# document.cleaning_completed_at = datetime.datetime.utcnow()
|
||||
# document.parsing_completed_at = datetime.datetime.utcnow()
|
||||
# document.completed_at = datetime.datetime.utcnow()
|
||||
# document.indexing_status = 'completed'
|
||||
# document.word_count = 0
|
||||
# document.tokens = 0
|
||||
# document.indexing_latency = 0
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
# if page['type'] != 'database':
|
||||
document_ids.append(document.id)
|
||||
documents.append(document)
|
||||
position += 1
|
||||
else:
|
||||
exist_document.pop(page['page_id'])
|
||||
# delete not selected documents
|
||||
if len(exist_document) > 0:
|
||||
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
|
||||
db.session.commit()
|
||||
|
||||
# trigger async task
|
||||
document_indexing_task.delay(document.dataset_id, document.id)
|
||||
document_indexing_task.delay(dataset.id, document_ids)
|
||||
|
||||
return documents, batch
|
||||
|
||||
@staticmethod
|
||||
def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
|
||||
created_from: str, position: int, account: Account, name: str, batch: str):
|
||||
document = Document(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=position,
|
||||
data_source_type=data_source_type,
|
||||
data_source_info=json.dumps(data_source_info),
|
||||
dataset_process_rule_id=process_rule_id,
|
||||
batch=batch,
|
||||
name=name,
|
||||
created_from=created_from,
|
||||
created_by=account.id,
|
||||
)
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
@@ -460,20 +550,42 @@ class DocumentService:
|
||||
file_name = ''
|
||||
data_source_info = {}
|
||||
if document_data["data_source"]["type"] == "upload_file":
|
||||
file_id = document_data["data_source"]["info"]
|
||||
file = db.session.query(UploadFile).filter(
|
||||
UploadFile.tenant_id == dataset.tenant_id,
|
||||
UploadFile.id == file_id
|
||||
).first()
|
||||
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
|
||||
for file_id in upload_file_list:
|
||||
file = db.session.query(UploadFile).filter(
|
||||
UploadFile.tenant_id == dataset.tenant_id,
|
||||
UploadFile.id == file_id
|
||||
).first()
|
||||
|
||||
# raise error if file not found
|
||||
if not file:
|
||||
raise FileNotExistsError()
|
||||
# raise error if file not found
|
||||
if not file:
|
||||
raise FileNotExistsError()
|
||||
|
||||
file_name = file.name
|
||||
data_source_info = {
|
||||
"upload_file_id": file_id,
|
||||
}
|
||||
file_name = file.name
|
||||
data_source_info = {
|
||||
"upload_file_id": file_id,
|
||||
}
|
||||
elif document_data["data_source"]["type"] == "notion_import":
|
||||
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info['workspace_id']
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
if not data_source_binding:
|
||||
raise ValueError('Data source binding not found.')
|
||||
for page in notion_info['pages']:
|
||||
data_source_info = {
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_page_id": page['page_id'],
|
||||
"notion_page_icon": page['page_icon'],
|
||||
"type": page['type']
|
||||
}
|
||||
document.data_source_type = document_data["data_source"]["type"]
|
||||
document.data_source_info = json.dumps(data_source_info)
|
||||
document.name = file_name
|
||||
@@ -513,15 +625,15 @@ class DocumentService:
|
||||
db.session.add(dataset)
|
||||
db.session.flush()
|
||||
|
||||
document = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
|
||||
|
||||
cut_length = 18
|
||||
cut_name = document.name[:cut_length]
|
||||
dataset.name = cut_name + '...' if len(document.name) > cut_length else cut_name
|
||||
dataset.description = 'useful for when you want to answer queries about the ' + document.name
|
||||
cut_name = documents[0].name[:cut_length]
|
||||
dataset.name = cut_name + '...'
|
||||
dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name
|
||||
db.session.commit()
|
||||
|
||||
return dataset, document
|
||||
return dataset, documents, batch
|
||||
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
@@ -552,9 +664,15 @@ class DocumentService:
|
||||
if args['data_source']['type'] not in Document.DATA_SOURCES:
|
||||
raise ValueError("Data source type is invalid")
|
||||
|
||||
if 'info_list' not in args['data_source'] or not args['data_source']['info_list']:
|
||||
raise ValueError("Data source info is required")
|
||||
|
||||
if args['data_source']['type'] == 'upload_file':
|
||||
if 'info' not in args['data_source'] or not args['data_source']['info']:
|
||||
raise ValueError("Data source info is required")
|
||||
if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['file_info_list']:
|
||||
raise ValueError("File source info is required")
|
||||
if args['data_source']['type'] == 'notion_import':
|
||||
if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['notion_info_list']:
|
||||
raise ValueError("Notion source info is required")
|
||||
|
||||
@classmethod
|
||||
def process_rule_args_validate(cls, args: dict):
|
||||
@@ -624,3 +742,78 @@ class DocumentService:
|
||||
|
||||
if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
|
||||
raise ValueError("Process rule segmentation max_tokens is invalid")
|
||||
|
||||
@classmethod
|
||||
def estimate_args_validate(cls, args: dict):
|
||||
if 'info_list' not in args or not args['info_list']:
|
||||
raise ValueError("Data source info is required")
|
||||
|
||||
if not isinstance(args['info_list'], dict):
|
||||
raise ValueError("Data info is invalid")
|
||||
|
||||
if 'process_rule' not in args or not args['process_rule']:
|
||||
raise ValueError("Process rule is required")
|
||||
|
||||
if not isinstance(args['process_rule'], dict):
|
||||
raise ValueError("Process rule is invalid")
|
||||
|
||||
if 'mode' not in args['process_rule'] or not args['process_rule']['mode']:
|
||||
raise ValueError("Process rule mode is required")
|
||||
|
||||
if args['process_rule']['mode'] not in DatasetProcessRule.MODES:
|
||||
raise ValueError("Process rule mode is invalid")
|
||||
|
||||
if args['process_rule']['mode'] == 'automatic':
|
||||
args['process_rule']['rules'] = {}
|
||||
else:
|
||||
if 'rules' not in args['process_rule'] or not args['process_rule']['rules']:
|
||||
raise ValueError("Process rule rules is required")
|
||||
|
||||
if not isinstance(args['process_rule']['rules'], dict):
|
||||
raise ValueError("Process rule rules is invalid")
|
||||
|
||||
if 'pre_processing_rules' not in args['process_rule']['rules'] \
|
||||
or args['process_rule']['rules']['pre_processing_rules'] is None:
|
||||
raise ValueError("Process rule pre_processing_rules is required")
|
||||
|
||||
if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):
|
||||
raise ValueError("Process rule pre_processing_rules is invalid")
|
||||
|
||||
unique_pre_processing_rule_dicts = {}
|
||||
for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']:
|
||||
if 'id' not in pre_processing_rule or not pre_processing_rule['id']:
|
||||
raise ValueError("Process rule pre_processing_rules id is required")
|
||||
|
||||
if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES:
|
||||
raise ValueError("Process rule pre_processing_rules id is invalid")
|
||||
|
||||
if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None:
|
||||
raise ValueError("Process rule pre_processing_rules enabled is required")
|
||||
|
||||
if not isinstance(pre_processing_rule['enabled'], bool):
|
||||
raise ValueError("Process rule pre_processing_rules enabled is invalid")
|
||||
|
||||
unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule
|
||||
|
||||
args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())
|
||||
|
||||
if 'segmentation' not in args['process_rule']['rules'] \
|
||||
or args['process_rule']['rules']['segmentation'] is None:
|
||||
raise ValueError("Process rule segmentation is required")
|
||||
|
||||
if not isinstance(args['process_rule']['rules']['segmentation'], dict):
|
||||
raise ValueError("Process rule segmentation is invalid")
|
||||
|
||||
if 'separator' not in args['process_rule']['rules']['segmentation'] \
|
||||
or not args['process_rule']['rules']['segmentation']['separator']:
|
||||
raise ValueError("Process rule segmentation separator is required")
|
||||
|
||||
if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):
|
||||
raise ValueError("Process rule segmentation separator is invalid")
|
||||
|
||||
if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \
|
||||
or not args['process_rule']['rules']['segmentation']['max_tokens']:
|
||||
raise ValueError("Process rule segmentation max_tokens is required")
|
||||
|
||||
if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
|
||||
raise ValueError("Process rule segmentation max_tokens is invalid")
|
||||
|
Reference in New Issue
Block a user