From bc9efa7ea827c931a4ac4454851f2d3c8d3765fb Mon Sep 17 00:00:00 2001 From: Yongtao Huang Date: Wed, 3 Sep 2025 08:56:48 +0800 Subject: [PATCH] Refactor: use DatasourceType.XX.value instead of hardcoded (#25015) Signed-off-by: Yongtao Huang Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/workflow.py | 1 - api/controllers/console/datasets/data_source.py | 3 ++- api/controllers/console/datasets/datasets.py | 9 ++++++--- api/controllers/console/datasets/datasets_document.py | 9 +++++---- api/core/indexing_runner.py | 9 ++++++--- api/core/rag/extractor/extract_processor.py | 4 ++-- 6 files changed, 21 insertions(+), 14 deletions(-) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index e36f308bd..9f829e27f 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -526,7 +526,6 @@ class PublishedWorkflowApi(Resource): ) app_model.workflow_id = workflow.id - db.session.commit() workflow_created_at = TimestampField().format(workflow.created_at) diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 6083a53be..e4d5f1be6 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -10,6 +10,7 @@ from werkzeug.exceptions import NotFound from controllers.console import api from controllers.console.wraps import account_initialization_required, setup_required from core.indexing_runner import IndexingRunner +from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.notion_extractor import NotionExtractor from extensions.ext_database import db @@ -214,7 +215,7 @@ class DataSourceNotionApi(Resource): workspace_id = notion_info["workspace_id"] for page in notion_info["pages"]: extract_setting = ExtractSetting( - datasource_type="notion_import", + datasource_type=DatasourceType.NOTION.value, notion_info={ "notion_workspace_id": workspace_id, "notion_obj_id": page["page_id"], diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index a5a18e7f3..11b7b1fec 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -22,6 +22,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.plugin.entities.plugin import ModelProviderID from core.provider_manager import ProviderManager from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db @@ -422,7 +423,9 @@ class DatasetIndexingEstimateApi(Resource): if file_details: for file_detail in file_details: extract_setting = ExtractSetting( - datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"] + datasource_type=DatasourceType.FILE.value, + upload_file=file_detail, + document_model=args["doc_form"], ) extract_settings.append(extract_setting) elif args["info_list"]["data_source_type"] == "notion_import": @@ -431,7 +434,7 @@ class DatasetIndexingEstimateApi(Resource): workspace_id = notion_info["workspace_id"] for page in notion_info["pages"]: extract_setting = ExtractSetting( - datasource_type="notion_import", + datasource_type=DatasourceType.NOTION.value, notion_info={ "notion_workspace_id": workspace_id, "notion_obj_id": page["page_id"], @@ -445,7 +448,7 @@ class DatasetIndexingEstimateApi(Resource): website_info_list = args["info_list"]["website_info_list"] for url in website_info_list["urls"]: extract_setting = ExtractSetting( - datasource_type="website_crawl", + datasource_type=DatasourceType.WEBSITE.value, website_info={ "provider": website_info_list["provider"], "job_id": website_info_list["job_id"], diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 22bb81f9e..f9703f5a2 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -40,6 +40,7 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.plugin.impl.exc import PluginDaemonClientSideError +from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting from extensions.ext_database import db from fields.document_fields import ( @@ -425,7 +426,7 @@ class DocumentIndexingEstimateApi(DocumentResource): raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type="upload_file", upload_file=file, document_model=document.doc_form + datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form ) indexing_runner = IndexingRunner() @@ -485,13 +486,13 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form + datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form ) extract_settings.append(extract_setting) elif document.data_source_type == "notion_import": extract_setting = ExtractSetting( - datasource_type="notion_import", + datasource_type=DatasourceType.NOTION.value, notion_info={ "notion_workspace_id": data_source_info["notion_workspace_id"], "notion_obj_id": data_source_info["notion_page_id"], @@ -503,7 +504,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): extract_settings.append(extract_setting) elif document.data_source_type == "website_crawl": extract_setting = ExtractSetting( - datasource_type="website_crawl", + datasource_type=DatasourceType.WEBSITE.value, website_info={ "provider": data_source_info["provider"], "job_id": data_source_info["job_id"], diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 4a768618f..d31109f7a 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -19,6 +19,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore +from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_base import BaseIndexProcessor @@ -340,7 +341,9 @@ class IndexingRunner: if file_detail: extract_setting = ExtractSetting( - datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form + datasource_type=DatasourceType.FILE.value, + upload_file=file_detail, + document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) elif dataset_document.data_source_type == "notion_import": @@ -351,7 +354,7 @@ class IndexingRunner: ): raise ValueError("no notion import info found") extract_setting = ExtractSetting( - datasource_type="notion_import", + datasource_type=DatasourceType.NOTION.value, notion_info={ "notion_workspace_id": data_source_info["notion_workspace_id"], "notion_obj_id": data_source_info["notion_page_id"], @@ -371,7 +374,7 @@ class IndexingRunner: ): raise ValueError("no website import info found") extract_setting = ExtractSetting( - datasource_type="website_crawl", + datasource_type=DatasourceType.WEBSITE.value, website_info={ "provider": data_source_info["provider"], "job_id": data_source_info["job_id"], diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index e6b28b1bf..b5ea08173 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -45,7 +45,7 @@ class ExtractProcessor: cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False ) -> Union[list[Document], str]: extract_setting = ExtractSetting( - datasource_type="upload_file", upload_file=upload_file, document_model="text_model" + datasource_type=DatasourceType.FILE.value, upload_file=upload_file, document_model="text_model" ) if return_text: delimiter = "\n" @@ -76,7 +76,7 @@ class ExtractProcessor: # https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521 file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}" Path(file_path).write_bytes(response.content) - extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model") + extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE.value, document_model="text_model") if return_text: delimiter = "\n" return delimiter.join(