Feat/dataset notion import (#392)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
This commit is contained in:
Jyong
2023-06-16 21:47:51 +08:00
committed by GitHub
parent f350948bde
commit 9253f72dea
96 changed files with 4479 additions and 367 deletions

View File

@@ -3,7 +3,7 @@ import logging
import datetime
import time
import random
from typing import Optional
from typing import Optional, List
from extensions.ext_redis import redis_client
from flask_login import current_user
@@ -14,10 +14,12 @@ from extensions.ext_database import db
from models.account import Account
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
from models.model import UploadFile
from models.source import DataSourceBinding
from services.errors.account import NoPermissionError
from services.errors.dataset import DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task
@@ -286,6 +288,24 @@ class DocumentService:
return document
@staticmethod
def get_document_by_dataset_id(dataset_id: str) -> List[Document]:
documents = db.session.query(Document).filter(
Document.dataset_id == dataset_id,
Document.enabled == True
).all()
return documents
@staticmethod
def get_batch_documents(dataset_id: str, batch: str) -> List[Document]:
documents = db.session.query(Document).filter(
Document.batch == batch,
Document.dataset_id == dataset_id,
Document.tenant_id == current_user.current_tenant_id
).all()
return documents
@staticmethod
def get_document_file_detail(file_id: str):
file_detail = db.session.query(UploadFile). \
filter(UploadFile.id == file_id). \
@@ -344,9 +364,9 @@ class DocumentService:
@staticmethod
def get_documents_position(dataset_id):
documents = Document.query.filter_by(dataset_id=dataset_id).all()
if documents:
return len(documents) + 1
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
if document:
return document.position + 1
else:
return 1
@@ -363,9 +383,11 @@ class DocumentService:
if dataset.indexing_technique == 'high_quality':
IndexBuilder.get_default_service_context(dataset.tenant_id)
documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
if 'original_document_id' in document_data and document_data["original_document_id"]:
document = DocumentService.update_document_with_dataset_id(dataset, document_data, account)
documents.append(document)
else:
# save process rule
if not dataset_process_rule:
@@ -386,46 +408,114 @@ class DocumentService:
)
db.session.add(dataset_process_rule)
db.session.commit()
file_name = ''
data_source_info = {}
if document_data["data_source"]["type"] == "upload_file":
file_id = document_data["data_source"]["info"]
file = db.session.query(UploadFile).filter(
UploadFile.tenant_id == dataset.tenant_id,
UploadFile.id == file_id
).first()
# raise error if file not found
if not file:
raise FileNotExistsError()
file_name = file.name
data_source_info = {
"upload_file_id": file_id,
}
# save document
position = DocumentService.get_documents_position(dataset.id)
document = Document(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=position,
data_source_type=document_data["data_source"]["type"],
data_source_info=json.dumps(data_source_info),
dataset_process_rule_id=dataset_process_rule.id,
batch=time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)),
name=file_name,
created_from=created_from,
created_by=account.id,
# created_api_request_id = db.Column(UUID, nullable=True)
)
document_ids = []
if document_data["data_source"]["type"] == "upload_file":
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
for file_id in upload_file_list:
file = db.session.query(UploadFile).filter(
UploadFile.tenant_id == dataset.tenant_id,
UploadFile.id == file_id
).first()
db.session.add(document)
# raise error if file not found
if not file:
raise FileNotExistsError()
file_name = file.name
data_source_info = {
"upload_file_id": file_id,
}
document = DocumentService.save_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
data_source_info, created_from, position,
account, file_name, batch)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
elif document_data["data_source"]["type"] == "notion_import":
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
exist_page_ids = []
exist_document = dict()
documents = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type='notion_import',
enabled=True
).all()
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info['notion_page_id'])
exist_document[data_source_info['notion_page_id']] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
raise ValueError('Data source binding not found.')
for page in notion_info['pages']:
if page['page_id'] not in exist_page_ids:
data_source_info = {
"notion_workspace_id": workspace_id,
"notion_page_id": page['page_id'],
"notion_page_icon": page['page_icon'],
"type": page['type']
}
document = DocumentService.save_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
data_source_info, created_from, position,
account, page['page_name'], batch)
# if page['type'] == 'database':
# document.splitting_completed_at = datetime.datetime.utcnow()
# document.cleaning_completed_at = datetime.datetime.utcnow()
# document.parsing_completed_at = datetime.datetime.utcnow()
# document.completed_at = datetime.datetime.utcnow()
# document.indexing_status = 'completed'
# document.word_count = 0
# document.tokens = 0
# document.indexing_latency = 0
db.session.add(document)
db.session.flush()
# if page['type'] != 'database':
document_ids.append(document.id)
documents.append(document)
position += 1
else:
exist_document.pop(page['page_id'])
# delete not selected documents
if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
db.session.commit()
# trigger async task
document_indexing_task.delay(document.dataset_id, document.id)
document_indexing_task.delay(dataset.id, document_ids)
return documents, batch
@staticmethod
def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
created_from: str, position: int, account: Account, name: str, batch: str):
document = Document(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=position,
data_source_type=data_source_type,
data_source_info=json.dumps(data_source_info),
dataset_process_rule_id=process_rule_id,
batch=batch,
name=name,
created_from=created_from,
created_by=account.id,
)
return document
@staticmethod
@@ -460,20 +550,42 @@ class DocumentService:
file_name = ''
data_source_info = {}
if document_data["data_source"]["type"] == "upload_file":
file_id = document_data["data_source"]["info"]
file = db.session.query(UploadFile).filter(
UploadFile.tenant_id == dataset.tenant_id,
UploadFile.id == file_id
).first()
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
for file_id in upload_file_list:
file = db.session.query(UploadFile).filter(
UploadFile.tenant_id == dataset.tenant_id,
UploadFile.id == file_id
).first()
# raise error if file not found
if not file:
raise FileNotExistsError()
# raise error if file not found
if not file:
raise FileNotExistsError()
file_name = file.name
data_source_info = {
"upload_file_id": file_id,
}
file_name = file.name
data_source_info = {
"upload_file_id": file_id,
}
elif document_data["data_source"]["type"] == "notion_import":
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
raise ValueError('Data source binding not found.')
for page in notion_info['pages']:
data_source_info = {
"notion_workspace_id": workspace_id,
"notion_page_id": page['page_id'],
"notion_page_icon": page['page_icon'],
"type": page['type']
}
document.data_source_type = document_data["data_source"]["type"]
document.data_source_info = json.dumps(data_source_info)
document.name = file_name
@@ -513,15 +625,15 @@ class DocumentService:
db.session.add(dataset)
db.session.flush()
document = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
documents, batch = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
cut_length = 18
cut_name = document.name[:cut_length]
dataset.name = cut_name + '...' if len(document.name) > cut_length else cut_name
dataset.description = 'useful for when you want to answer queries about the ' + document.name
cut_name = documents[0].name[:cut_length]
dataset.name = cut_name + '...'
dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name
db.session.commit()
return dataset, document
return dataset, documents, batch
@classmethod
def document_create_args_validate(cls, args: dict):
@@ -552,9 +664,15 @@ class DocumentService:
if args['data_source']['type'] not in Document.DATA_SOURCES:
raise ValueError("Data source type is invalid")
if 'info_list' not in args['data_source'] or not args['data_source']['info_list']:
raise ValueError("Data source info is required")
if args['data_source']['type'] == 'upload_file':
if 'info' not in args['data_source'] or not args['data_source']['info']:
raise ValueError("Data source info is required")
if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['file_info_list']:
raise ValueError("File source info is required")
if args['data_source']['type'] == 'notion_import':
if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['notion_info_list']:
raise ValueError("Notion source info is required")
@classmethod
def process_rule_args_validate(cls, args: dict):
@@ -624,3 +742,78 @@ class DocumentService:
if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
raise ValueError("Process rule segmentation max_tokens is invalid")
@classmethod
def estimate_args_validate(cls, args: dict):
if 'info_list' not in args or not args['info_list']:
raise ValueError("Data source info is required")
if not isinstance(args['info_list'], dict):
raise ValueError("Data info is invalid")
if 'process_rule' not in args or not args['process_rule']:
raise ValueError("Process rule is required")
if not isinstance(args['process_rule'], dict):
raise ValueError("Process rule is invalid")
if 'mode' not in args['process_rule'] or not args['process_rule']['mode']:
raise ValueError("Process rule mode is required")
if args['process_rule']['mode'] not in DatasetProcessRule.MODES:
raise ValueError("Process rule mode is invalid")
if args['process_rule']['mode'] == 'automatic':
args['process_rule']['rules'] = {}
else:
if 'rules' not in args['process_rule'] or not args['process_rule']['rules']:
raise ValueError("Process rule rules is required")
if not isinstance(args['process_rule']['rules'], dict):
raise ValueError("Process rule rules is invalid")
if 'pre_processing_rules' not in args['process_rule']['rules'] \
or args['process_rule']['rules']['pre_processing_rules'] is None:
raise ValueError("Process rule pre_processing_rules is required")
if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):
raise ValueError("Process rule pre_processing_rules is invalid")
unique_pre_processing_rule_dicts = {}
for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']:
if 'id' not in pre_processing_rule or not pre_processing_rule['id']:
raise ValueError("Process rule pre_processing_rules id is required")
if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES:
raise ValueError("Process rule pre_processing_rules id is invalid")
if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None:
raise ValueError("Process rule pre_processing_rules enabled is required")
if not isinstance(pre_processing_rule['enabled'], bool):
raise ValueError("Process rule pre_processing_rules enabled is invalid")
unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule
args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())
if 'segmentation' not in args['process_rule']['rules'] \
or args['process_rule']['rules']['segmentation'] is None:
raise ValueError("Process rule segmentation is required")
if not isinstance(args['process_rule']['rules']['segmentation'], dict):
raise ValueError("Process rule segmentation is invalid")
if 'separator' not in args['process_rule']['rules']['segmentation'] \
or not args['process_rule']['rules']['segmentation']['separator']:
raise ValueError("Process rule segmentation separator is required")
if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):
raise ValueError("Process rule segmentation separator is invalid")
if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \
or not args['process_rule']['rules']['segmentation']['max_tokens']:
raise ValueError("Process rule segmentation max_tokens is required")
if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
raise ValueError("Process rule segmentation max_tokens is invalid")