Feat/firecrawl data source (#5232)
Co-authored-by: Nicolas <nicolascamara29@gmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: takatost <takatost@gmail.com>
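This commit wires Firecrawl-backed website crawling into the dataset service: documents of type website_crawl are created one per crawled URL, the new DocumentService.sync_website_document method and sync_website_document_indexing_task re-sync them, and the Notion OAuth binding model is renamed from DataSourceBinding to DataSourceOauthBinding throughout.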
@@ -31,7 +31,7 @@ from models.dataset import (
     DocumentSegment,
 )
 from models.model import UploadFile
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding
 from services.errors.account import NoPermissionError
 from services.errors.dataset import DatasetInUseError, DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
@@ -48,6 +48,7 @@ from tasks.document_indexing_update_task import document_indexing_update_task
 from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
 from tasks.recover_document_indexing_task import recover_document_indexing_task
 from tasks.retry_document_indexing_task import retry_document_indexing_task
+from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task


 class DatasetService:
@@ -508,18 +509,40 @@ class DocumentService:
     @staticmethod
     def retry_document(dataset_id: str, documents: list[Document]):
         for document in documents:
+            # add retry flag
+            retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
+            cache_result = redis_client.get(retry_indexing_cache_key)
+            if cache_result is not None:
+                raise ValueError("Document is being retried, please try again later")
             # retry document indexing
             document.indexing_status = 'waiting'
             db.session.add(document)
             db.session.commit()
             # add retry flag
             retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
+
             redis_client.setex(retry_indexing_cache_key, 600, 1)
         # trigger async task
         document_ids = [document.id for document in documents]
         retry_document_indexing_task.delay(dataset_id, document_ids)

+    @staticmethod
+    def sync_website_document(dataset_id: str, document: Document):
+        # add sync flag
+        sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id)
+        cache_result = redis_client.get(sync_indexing_cache_key)
+        if cache_result is not None:
+            raise ValueError("Document is being synced, please try again later")
+        # sync document indexing
+        document.indexing_status = 'waiting'
+        data_source_info = document.data_source_info_dict
+        data_source_info['mode'] = 'scrape'
+        document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
+        db.session.add(document)
+        db.session.commit()
+
+        redis_client.setex(sync_indexing_cache_key, 600, 1)
+
+        sync_website_document_indexing_task.delay(dataset_id, document.id)
     @staticmethod
     def get_documents_position(dataset_id):
         document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
         if document:
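Reviewer note: retry_document and the new sync_website_document share the same guard pattern. A Redis key with a 600-second TTL acts as a soft lock so the same document cannot be queued twice, the status change is committed first, and the heavy work is handed off to an async worker via .delay(). A minimal sketch of a hypothetical caller (the import path and surrounding variables are assumptions, not part of this diff):

    from services.dataset_service import DocumentService

    try:
        DocumentService.sync_website_document(dataset_id, document)
    except ValueError:
        # A sync was already triggered within the last 10 minutes; report it back to the user.
        pass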
@@ -545,6 +568,9 @@ class DocumentService:
             notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
             for notion_info in notion_info_list:
                 count = count + len(notion_info['pages'])
+        elif document_data["data_source"]["type"] == "website_crawl":
+            website_info = document_data["data_source"]['info_list']['website_info_list']
+            count = len(website_info['urls'])
         batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
         if count > batch_upload_limit:
             raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -683,12 +709,12 @@ class DocumentService:
                         exist_document[data_source_info['notion_page_id']] = document.id
                 for notion_info in notion_info_list:
                     workspace_id = notion_info['workspace_id']
-                    data_source_binding = DataSourceBinding.query.filter(
+                    data_source_binding = DataSourceOauthBinding.query.filter(
                         db.and_(
-                            DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                            DataSourceBinding.provider == 'notion',
-                            DataSourceBinding.disabled == False,
-                            DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                            DataSourceOauthBinding.provider == 'notion',
+                            DataSourceOauthBinding.disabled == False,
+                            DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                         )
                     ).first()
                     if not data_source_binding:
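Reviewer note: the comparison source_info['workspace_id'] == f'"{workspace_id}"' matches against the JSON-encoded value of the column, which appears to be why the expected workspace_id is wrapped in literal double quotes; the rename from DataSourceBinding changes only the model name, not this behavior.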
@@ -717,6 +743,28 @@ class DocumentService:
                 # delete not selected documents
                 if len(exist_document) > 0:
                     clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
+            elif document_data["data_source"]["type"] == "website_crawl":
+                website_info = document_data["data_source"]['info_list']['website_info_list']
+                urls = website_info['urls']
+                for url in urls:
+                    data_source_info = {
+                        'url': url,
+                        'provider': website_info['provider'],
+                        'job_id': website_info['job_id'],
+                        'only_main_content': website_info.get('only_main_content', False),
+                        'mode': 'crawl',
+                    }
+                    document = DocumentService.build_document(dataset, dataset_process_rule.id,
+                                                              document_data["data_source"]["type"],
+                                                              document_data["doc_form"],
+                                                              document_data["doc_language"],
+                                                              data_source_info, created_from, position,
+                                                              account, url, batch)
+                    db.session.add(document)
+                    db.session.flush()
+                    document_ids.append(document.id)
+                    documents.append(document)
+                    position += 1
             db.session.commit()

             # trigger async task
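Reviewer note: each crawled URL becomes its own Document row, with data_source_info recording the crawl provenance and mode set to 'crawl'; sync_website_document above later rewrites mode to 'scrape' before re-indexing. The stored JSON looks roughly like this (values are illustrative):

    {"url": "https://example.com/docs", "provider": "firecrawl", "job_id": "<job id>", "only_main_content": false, "mode": "crawl"}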
@@ -818,12 +866,12 @@ class DocumentService:
                 notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                 for notion_info in notion_info_list:
                     workspace_id = notion_info['workspace_id']
-                    data_source_binding = DataSourceBinding.query.filter(
+                    data_source_binding = DataSourceOauthBinding.query.filter(
                         db.and_(
-                            DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                            DataSourceBinding.provider == 'notion',
-                            DataSourceBinding.disabled == False,
-                            DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                            DataSourceOauthBinding.provider == 'notion',
+                            DataSourceOauthBinding.disabled == False,
+                            DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                         )
                     ).first()
                     if not data_source_binding:
@@ -835,6 +883,17 @@ class DocumentService:
                         "notion_page_icon": page['page_icon'],
                         "type": page['type']
                     }
+            elif document_data["data_source"]["type"] == "website_crawl":
+                website_info = document_data["data_source"]['info_list']['website_info_list']
+                urls = website_info['urls']
+                for url in urls:
+                    data_source_info = {
+                        'url': url,
+                        'provider': website_info['provider'],
+                        'job_id': website_info['job_id'],
+                        'only_main_content': website_info.get('only_main_content', False),
+                        'mode': 'crawl',
+                    }
             document.data_source_type = document_data["data_source"]["type"]
             document.data_source_info = json.dumps(data_source_info)
             document.name = file_name
@@ -873,6 +932,9 @@ class DocumentService:
             notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
             for notion_info in notion_info_list:
                 count = count + len(notion_info['pages'])
+        elif document_data["data_source"]["type"] == "website_crawl":
+            website_info = document_data["data_source"]['info_list']['website_info_list']
+            count = len(website_info['urls'])
         batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
         if count > batch_upload_limit:
             raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -973,6 +1035,10 @@ class DocumentService:
             if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
                 'notion_info_list']:
                 raise ValueError("Notion source info is required")
+        if args['data_source']['type'] == 'website_crawl':
+            if 'website_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
+                'website_info_list']:
+                raise ValueError("Website source info is required")

     @classmethod
     def process_rule_args_validate(cls, args: dict):
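Reviewer note: a create request that passes this validation carries a data_source block shaped roughly as follows; the keys come from this diff, the values are illustrative:

    "data_source": {
        "type": "website_crawl",
        "info_list": {
            "website_info_list": {
                "provider": "firecrawl",
                "job_id": "<crawl job id>",
                "urls": ["https://example.com/docs"],
                "only_main_content": true
            }
        }
    }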