Feat/firecrawl data source (#5232)

Co-authored-by: Nicolas <nicolascamara29@gmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: takatost <takatost@gmail.com>
Jyong authored 2024-06-15 02:46:02 +08:00, committed by GitHub
parent 918ebe1620
commit ba5f8afaa8
36 changed files with 1174 additions and 64 deletions


@@ -31,7 +31,7 @@ from models.dataset import (
     DocumentSegment,
 )
 from models.model import UploadFile
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding
 from services.errors.account import NoPermissionError
 from services.errors.dataset import DatasetInUseError, DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
@@ -48,6 +48,7 @@ from tasks.document_indexing_update_task import document_indexing_update_task
 from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
 from tasks.recover_document_indexing_task import recover_document_indexing_task
 from tasks.retry_document_indexing_task import retry_document_indexing_task
+from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task


 class DatasetService:
@@ -508,18 +509,40 @@ class DocumentService:
     @staticmethod
     def retry_document(dataset_id: str, documents: list[Document]):
         for document in documents:
             # add retry flag
             retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
             cache_result = redis_client.get(retry_indexing_cache_key)
             if cache_result is not None:
                 raise ValueError("Document is being retried, please try again later")
             # retry document indexing
             document.indexing_status = 'waiting'
             db.session.add(document)
             db.session.commit()
             # add retry flag
             retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
             redis_client.setex(retry_indexing_cache_key, 600, 1)
         # trigger async task
         document_ids = [document.id for document in documents]
         retry_document_indexing_task.delay(dataset_id, document_ids)

+    @staticmethod
+    def sync_website_document(dataset_id: str, document: Document):
+        # add sync flag
+        sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id)
+        cache_result = redis_client.get(sync_indexing_cache_key)
+        if cache_result is not None:
+            raise ValueError("Document is being synced, please try again later")
+        # sync document indexing
+        document.indexing_status = 'waiting'
+        data_source_info = document.data_source_info_dict
+        data_source_info['mode'] = 'scrape'
+        document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
+        db.session.add(document)
+        db.session.commit()
+
+        redis_client.setex(sync_indexing_cache_key, 600, 1)
+        sync_website_document_indexing_task.delay(dataset_id, document.id)
+
     @staticmethod
     def get_documents_position(dataset_id):
         document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
         if document:
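Context note, not part of the diff: the new sync_website_document method sets a short-lived Redis key as a re-entrancy guard (setex with a 600 s TTL) and then hands the work to Celery. The real task body lives in tasks/sync_website_document_indexing_task.py, which this hunk does not show; the skeleton below is only an assumed sketch of a task whose signature matches the .delay(dataset_id, document.id) call above.

from celery import shared_task

@shared_task
def sync_website_document_indexing_task(dataset_id: str, document_id: str):
    # Hypothetical body: re-crawl the document's URL (its data_source_info 'mode'
    # was flipped to 'scrape' by sync_website_document), rebuild its segments,
    # and rely on the 600 s TTL of 'document_{id}_is_sync' to release the guard.
    ...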
@@ -545,6 +568,9 @@ class DocumentService:
                     notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                     for notion_info in notion_info_list:
                         count = count + len(notion_info['pages'])
+                elif document_data["data_source"]["type"] == "website_crawl":
+                    website_info = document_data["data_source"]['info_list']['website_info_list']
+                    count = len(website_info['urls'])
                 batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
                 if count > batch_upload_limit:
                     raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -683,12 +709,12 @@ class DocumentService:
                         exist_document[data_source_info['notion_page_id']] = document.id
                 for notion_info in notion_info_list:
                     workspace_id = notion_info['workspace_id']
-                    data_source_binding = DataSourceBinding.query.filter(
+                    data_source_binding = DataSourceOauthBinding.query.filter(
                         db.and_(
-                            DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                            DataSourceBinding.provider == 'notion',
-                            DataSourceBinding.disabled == False,
-                            DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                            DataSourceOauthBinding.provider == 'notion',
+                            DataSourceOauthBinding.disabled == False,
+                            DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                         )
                     ).first()
                     if not data_source_binding:
@@ -717,6 +743,28 @@ class DocumentService:
                 # delete not selected documents
                 if len(exist_document) > 0:
                     clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
+            elif document_data["data_source"]["type"] == "website_crawl":
+                website_info = document_data["data_source"]['info_list']['website_info_list']
+                urls = website_info['urls']
+                for url in urls:
+                    data_source_info = {
+                        'url': url,
+                        'provider': website_info['provider'],
+                        'job_id': website_info['job_id'],
+                        'only_main_content': website_info.get('only_main_content', False),
+                        'mode': 'crawl',
+                    }
+                    document = DocumentService.build_document(dataset, dataset_process_rule.id,
+                                                              document_data["data_source"]["type"],
+                                                              document_data["doc_form"],
+                                                              document_data["doc_language"],
+                                                              data_source_info, created_from, position,
+                                                              account, url, batch)
+                    db.session.add(document)
+                    db.session.flush()
+                    document_ids.append(document.id)
+                    documents.append(document)
+                    position += 1
             db.session.commit()

             # trigger async task
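Putting the website pieces of this diff together: a document created from a website_crawl source stores a data_source_info blob like the sketch below (field values are hypothetical), with mode set to 'crawl' at creation time; sync_website_document later rewrites the same blob with mode 'scrape' before re-dispatching indexing.

import json

# Hypothetical example of the blob serialized onto Document.data_source_info
data_source_info = {
    'url': 'https://example.com/docs/getting-started',
    'provider': 'firecrawl',          # provider name carried in website_info_list
    'job_id': 'crawl-job-123',        # hypothetical crawl job identifier
    'only_main_content': True,
    'mode': 'crawl',                  # becomes 'scrape' when the document is re-synced
}
print(json.dumps(data_source_info, ensure_ascii=False))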
@@ -818,12 +866,12 @@ class DocumentService:
                 notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                 for notion_info in notion_info_list:
                     workspace_id = notion_info['workspace_id']
-                    data_source_binding = DataSourceBinding.query.filter(
+                    data_source_binding = DataSourceOauthBinding.query.filter(
                         db.and_(
-                            DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                            DataSourceBinding.provider == 'notion',
-                            DataSourceBinding.disabled == False,
-                            DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                            DataSourceOauthBinding.provider == 'notion',
+                            DataSourceOauthBinding.disabled == False,
+                            DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                         )
                     ).first()
                     if not data_source_binding:
@@ -835,6 +883,17 @@ class DocumentService:
"notion_page_icon": page['page_icon'],
"type": page['type']
}
elif document_data["data_source"]["type"] == "website_crawl":
website_info = document_data["data_source"]['info_list']['website_info_list']
urls = website_info['urls']
for url in urls:
data_source_info = {
'url': url,
'provider': website_info['provider'],
'job_id': website_info['job_id'],
'only_main_content': website_info.get('only_main_content', False),
'mode': 'crawl',
}
document.data_source_type = document_data["data_source"]["type"]
document.data_source_info = json.dumps(data_source_info)
document.name = file_name
@@ -873,6 +932,9 @@ class DocumentService:
                 notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                 for notion_info in notion_info_list:
                     count = count + len(notion_info['pages'])
+            elif document_data["data_source"]["type"] == "website_crawl":
+                website_info = document_data["data_source"]['info_list']['website_info_list']
+                count = len(website_info['urls'])
             batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -973,6 +1035,10 @@ class DocumentService:
             if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
                 'notion_info_list']:
                 raise ValueError("Notion source info is required")
+        if args['data_source']['type'] == 'website_crawl':
+            if 'website_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
+                'website_info_list']:
+                raise ValueError("Website source info is required")

     @classmethod
     def process_rule_args_validate(cls, args: dict):
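As a companion to the new validation branch, here is a sketch of an args payload that would pass the website_crawl checks; the field names come from this diff, while the values are hypothetical.

args = {
    'data_source': {
        'type': 'website_crawl',
        'info_list': {
            'website_info_list': {
                'provider': 'firecrawl',                 # e.g. the Firecrawl provider this PR adds
                'job_id': 'crawl-job-123',               # hypothetical crawl job id
                'urls': ['https://example.com/docs/getting-started'],
                'only_main_content': True,
            }
        }
    },
    # doc_form / doc_language / process_rule omitted; they are validated elsewhere.
}

# Mirrors the new check in data source validation
info_list = args['data_source']['info_list']
assert 'website_info_list' in info_list and info_list['website_info_list']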