From 9253f72dea646d6596c6347084bad16913f969e9 Mon Sep 17 00:00:00 2001
From: Jyong <76649700+JohnJyong@users.noreply.github.com>
Date: Fri, 16 Jun 2023 21:47:51 +0800
Subject: [PATCH] Feat/dataset notion import (#392)

Co-authored-by: StyleZhang
Co-authored-by: JzoNg
---
 api/.env.example | 1 +
 api/app.py | 2 +-
 api/config.py | 3 +
 api/controllers/console/__init__.py | 4 +-
 .../console/auth/data_source_oauth.py | 95 +++++
 .../console/datasets/data_source.py | 303 +++++++++++++++
 api/controllers/console/datasets/datasets.py | 83 +++-
 .../console/datasets/datasets_document.py | 205 +++++++++-
 .../service_api/dataset/document.py | 10 +-
 api/core/data_source/notion.py | 367 ++++++++++++++++++
 api/core/indexing_runner.py | 252 +++++++++---
 api/libs/oauth.py | 7 +
 api/libs/oauth_data_source.py | 256 ++++++++++++
 ...08af0a69ccefbb59fa80c778efee300bb780980.py | 46 +++
 api/models/dataset.py | 4 +-
 api/models/source.py | 21 +
 api/services/dataset_service.py | 311 ++++++++++++---
 api/tasks/clean_notion_document_task.py | 58 +++
 api/tasks/document_indexing_sync_task.py | 109 ++++++
 api/tasks/document_indexing_task.py | 30 +-
 api/tasks/document_indexing_update_task.py | 2 +-
 api/tasks/recover_document_indexing_task.py | 2 +-
 .../[datasetId]/layout.tsx | 3 +-
 web/app/components/app-sidebar/basic.tsx | 16 +-
 web/app/components/app-sidebar/index.tsx | 3 +-
 .../components/base/checkbox/assets/check.svg | 3 +
 .../components/base/checkbox/index.module.css | 9 +
 web/app/components/base/checkbox/index.tsx | 19 +
 .../base/notion-icon/index.module.css | 6 +
 web/app/components/base/notion-icon/index.tsx | 58 +++
 .../notion-page-selector/assets/clear.svg | 3 +
 .../assets/down-arrow.svg | 3 +
 .../assets/notion-empty-page.svg | 3 +
 .../assets/notion-page.svg | 3 +
 .../notion-page-selector/assets/search.svg | 5 +
 .../notion-page-selector/assets/setting.svg | 11 +
 .../base/notion-page-selector/base.module.css | 4 +
 .../base/notion-page-selector/base.tsx | 141 +++++++
 .../base/notion-page-selector/index.tsx | 2 +
 .../index.module.css | 28 ++
 .../notion-page-selector-modal/index.tsx | 62 +++
 .../page-selector/index.module.css | 17 +
 .../page-selector/index.tsx | 299 ++++++++++++++
 .../search-input/index.module.css | 15 +
 .../search-input/index.tsx | 42 ++
 .../workspace-selector/index.module.css | 9 +
 .../workspace-selector/index.tsx | 84 ++++
 .../components/base/progress-bar/index.tsx | 20 +
 .../datasets/create/assets/Icon-3-dots.svg | 3 +
 .../datasets/create/assets/normal.svg | 4 +
 .../datasets/create/assets/star.svg | 11 +
 .../create/embedding-process/index.module.css | 111 ++++++
 .../create/embedding-process/index.tsx | 242 ++++++++++++
 .../create/file-preview/index.module.css | 3 +
 .../datasets/create/file-preview/index.tsx | 27 +-
 web/app/components/datasets/create/index.tsx | 48 ++-
 .../notion-page-preview/index.module.css | 54 +++
 .../create/notion-page-preview/index.tsx | 75 ++++
 .../datasets/create/step-one/index.module.css | 50 +++
 .../datasets/create/step-one/index.tsx | 129 ++++--
 .../datasets/create/step-three/index.tsx | 23 +-
 .../datasets/create/step-two/index.module.css | 37 +-
 .../datasets/create/step-two/index.tsx | 125 +++++-
 .../datasets/documents/detail/index.tsx | 42 +-
 .../documents/detail/settings/index.tsx | 13 +-
 .../components/datasets/documents/index.tsx | 128 +++++-
 .../components/datasets/documents/list.tsx | 61 ++-
 .../datasets/documents/style.module.css | 3 +
 .../header/account-dropdown/index.tsx | 2 +-
 .../data-source-notion/index.tsx | 102 +++++
 .../operate/index.module.css | 14 +
 .../data-source-notion/operate/index.tsx | 107 +++++
 .../data-source-notion/style.module.css | 12 +
 .../data-source-page/index.module.css | 0
 .../data-source-page/index.tsx | 17 +
 .../header/account-setting/index.module.css | 10 +
 .../header/account-setting/index.tsx | 31 +-
 .../account-setting/members-page/index.tsx | 18 +-
 .../header/assets/data-source-blue.svg | 3 +
 .../components/header/assets/data-source.svg | 3 +
 web/app/components/header/assets/file.svg | 3 +
 web/app/components/header/assets/notion.svg | 12 +
 web/app/components/header/assets/sync.svg | 3 +
 web/app/components/header/assets/trash.svg | 3 +
 .../header/nav/nav-selector/index.tsx | 8 +-
 web/context/dataset-detail.ts | 4 +-
 web/i18n/lang/common.en.ts | 24 ++
 web/i18n/lang/common.zh.ts | 24 ++
 web/i18n/lang/dataset-creation.en.ts | 10 +-
 web/i18n/lang/dataset-creation.zh.ts | 10 +-
 web/i18n/lang/dataset-documents.en.ts | 2 +
 web/i18n/lang/dataset-documents.zh.ts | 2 +
 web/models/common.ts | 32 ++
 web/models/datasets.ts | 91 +++--
 web/service/common.ts | 21 +-
 web/service/datasets.ts | 50 ++-
 96 files changed, 4479 insertions(+), 367 deletions(-)
 create mode 100644 api/controllers/console/auth/data_source_oauth.py
 create mode 100644 api/controllers/console/datasets/data_source.py
 create mode 100644 api/core/data_source/notion.py
 create mode 100644 api/libs/oauth_data_source.py
 create mode 100644 api/migrations/versions/e32f6ccb87c6_e08af0a69ccefbb59fa80c778efee300bb780980.py
 create mode 100644 api/models/source.py
 create mode 100644 api/tasks/clean_notion_document_task.py
 create mode 100644 api/tasks/document_indexing_sync_task.py
 create mode 100644 web/app/components/base/checkbox/assets/check.svg
 create mode 100644 web/app/components/base/checkbox/index.module.css
 create mode 100644 web/app/components/base/checkbox/index.tsx
 create mode 100644 web/app/components/base/notion-icon/index.module.css
 create mode 100644 web/app/components/base/notion-icon/index.tsx
 create mode 100644 web/app/components/base/notion-page-selector/assets/clear.svg
 create mode 100644 web/app/components/base/notion-page-selector/assets/down-arrow.svg
 create mode 100644 web/app/components/base/notion-page-selector/assets/notion-empty-page.svg
 create mode 100644 web/app/components/base/notion-page-selector/assets/notion-page.svg
 create mode 100644 web/app/components/base/notion-page-selector/assets/search.svg
 create mode 100644 web/app/components/base/notion-page-selector/assets/setting.svg
 create mode 100644 web/app/components/base/notion-page-selector/base.module.css
 create mode 100644 web/app/components/base/notion-page-selector/base.tsx
 create mode 100644 web/app/components/base/notion-page-selector/index.tsx
 create mode 100644 web/app/components/base/notion-page-selector/notion-page-selector-modal/index.module.css
 create mode 100644 web/app/components/base/notion-page-selector/notion-page-selector-modal/index.tsx
 create mode 100644 web/app/components/base/notion-page-selector/page-selector/index.module.css
 create mode 100644 web/app/components/base/notion-page-selector/page-selector/index.tsx
 create mode 100644 web/app/components/base/notion-page-selector/search-input/index.module.css
 create mode 100644 web/app/components/base/notion-page-selector/search-input/index.tsx
 create mode 100644 web/app/components/base/notion-page-selector/workspace-selector/index.module.css
 create mode 100644 web/app/components/base/notion-page-selector/workspace-selector/index.tsx
 create mode 100644 web/app/components/base/progress-bar/index.tsx
 create mode 100644 web/app/components/datasets/create/assets/Icon-3-dots.svg
 create mode 100644 web/app/components/datasets/create/assets/normal.svg
 create mode 100644 web/app/components/datasets/create/assets/star.svg
 create mode 100644 web/app/components/datasets/create/embedding-process/index.module.css
 create mode 100644 web/app/components/datasets/create/embedding-process/index.tsx
 create mode 100644 web/app/components/datasets/create/notion-page-preview/index.module.css
 create mode 100644 web/app/components/datasets/create/notion-page-preview/index.tsx
 create mode 100644 web/app/components/header/account-setting/data-source-page/data-source-notion/index.tsx
 create mode 100644 web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.module.css
 create mode 100644 web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.tsx
 create mode 100644 web/app/components/header/account-setting/data-source-page/data-source-notion/style.module.css
 create mode 100644 web/app/components/header/account-setting/data-source-page/index.module.css
 create mode 100644 web/app/components/header/account-setting/data-source-page/index.tsx
 create mode 100644 web/app/components/header/assets/data-source-blue.svg
 create mode 100644 web/app/components/header/assets/data-source.svg
 create mode 100644 web/app/components/header/assets/file.svg
 create mode 100644 web/app/components/header/assets/notion.svg
 create mode 100644 web/app/components/header/assets/sync.svg
 create mode 100644 web/app/components/header/assets/trash.svg

diff --git a/api/.env.example b/api/.env.example
index b437beabd..b9c819535 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -22,6 +22,7 @@ CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
 # redis configuration
 REDIS_HOST=localhost
 REDIS_PORT=6379
+REDIS_USERNAME=
 REDIS_PASSWORD=difyai123456
 REDIS_DB=0
 
diff --git a/api/app.py b/api/app.py
index 2c6544c0c..4e217c6aa 100644
--- a/api/app.py
+++ b/api/app.py
@@ -20,7 +20,7 @@ from extensions.ext_database import db
 from extensions.ext_login import login_manager
 
 # DO NOT REMOVE BELOW
-from models import model, account, dataset, web, task
+from models import model, account, dataset, web, task, source
 from events import event_handlers
 # DO NOT REMOVE ABOVE
 
diff --git a/api/config.py b/api/config.py
index 80c5cbb67..e62f3b29a 100644
--- a/api/config.py
+++ b/api/config.py
@@ -187,6 +187,9 @@ class Config:
         # For temp use only
         # set default LLM provider, default is 'openai', support `azure_openai`
         self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
+        # notion import setting
+        self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
+        self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
 
 
 class CloudEditionConfig(Config):
diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index 9c68e8ecd..f3d245895 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -12,10 +12,10 @@ from . import setup, version, apikey, admin
 from .app import app, site, completion, model_config, statistic, conversation, message, generator
 
 # Import auth controllers
-from .auth import login, oauth
+from .auth import login, oauth, data_source_oauth
 
 # Import datasets controllers
-from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing
+from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
 
 # Import workspace controllers
 from .workspace import workspace, members, providers, account
diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py
new file mode 100644
index 000000000..0ac6af585
--- /dev/null
+++ b/api/controllers/console/auth/data_source_oauth.py
@@ -0,0 +1,95 @@
+import logging
+from datetime import datetime
+from typing import Optional
+
+import flask_login
+import requests
+from flask import request, redirect, current_app, session
+from flask_login import current_user, login_required
+from flask_restful import Resource
+from werkzeug.exceptions import Forbidden
+from libs.oauth_data_source import NotionOAuth
+from controllers.console import api
+from ..setup import setup_required
+from ..wraps import account_initialization_required
+
+
+def get_oauth_providers():
+    with current_app.app_context():
+        notion_oauth = NotionOAuth(client_id=current_app.config.get('NOTION_CLIENT_ID'),
+                                   client_secret=current_app.config.get(
+                                       'NOTION_CLIENT_SECRET'),
+                                   redirect_uri=current_app.config.get(
+                                       'CONSOLE_URL') + '/console/api/oauth/data-source/callback/notion')
+
+        OAUTH_PROVIDERS = {
+            'notion': notion_oauth
+        }
+        return OAUTH_PROVIDERS
+
+
+class OAuthDataSource(Resource):
+    def get(self, provider: str):
+        # The role of the current user in the table must be admin or owner
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+        OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
+        with current_app.app_context():
+            oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
+            print(vars(oauth_provider))
+            if not oauth_provider:
+                return {'error': 'Invalid provider'}, 400
+
+            auth_url = oauth_provider.get_authorization_url()
+            return redirect(auth_url)
+
+
+class OAuthDataSourceCallback(Resource):
+    def get(self, provider: str):
+        OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
+        with current_app.app_context():
+            oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
+            if not oauth_provider:
+                return {'error': 'Invalid provider'}, 400
+        if 'code' in request.args:
+            code = request.args.get('code')
+            try:
+                oauth_provider.get_access_token(code)
+            except requests.exceptions.HTTPError as e:
+                logging.exception(
+                    f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
+                return {'error': 'OAuth data source process failed'}, 400
+
+            return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=success')
+        elif 'error' in request.args:
+            error = request.args.get('error')
+            return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source={error}')
+        else:
+            return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=access_denied')
+
+
+class OAuthDataSourceSync(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider, binding_id):
+        provider = str(provider)
+        binding_id = str(binding_id)
+        OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
+        with current_app.app_context():
+            oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
+            if not oauth_provider:
+                return {'error': 'Invalid provider'}, 400
+            try:
+                oauth_provider.sync_data_source(binding_id)
+            except requests.exceptions.HTTPError as e:
+                logging.exception(
+                    f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
+                return {'error': 'OAuth data source process failed'}, 400
+
+            return {'result': 'success'}, 200
+
+
+api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>')
+api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>')
+api.add_resource(OAuthDataSourceSync, '/oauth/data-source/<string:provider>/<uuid:binding_id>/sync')
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
new file mode 100644
index 000000000..f0efc0504
--- /dev/null
+++ b/api/controllers/console/datasets/data_source.py
@@ -0,0 +1,303 @@
+import datetime
+import json
+
+from cachetools import TTLCache
+from flask import request, current_app
+from flask_login import login_required, current_user
+from flask_restful import Resource, marshal_with, fields, reqparse, marshal
+from werkzeug.exceptions import NotFound
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from core.data_source.notion import NotionPageReader
+from core.indexing_runner import IndexingRunner
+from extensions.ext_database import db
+from libs.helper import TimestampField
+from libs.oauth_data_source import NotionOAuth
+from models.dataset import Document
+from models.source import DataSourceBinding
+from services.dataset_service import DatasetService, DocumentService
+from tasks.document_indexing_sync_task import document_indexing_sync_task
+
+cache = TTLCache(maxsize=None, ttl=30)
+
+FILE_SIZE_LIMIT = 15 * 1024 * 1024  # 15MB
+ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm']
+PREVIEW_WORDS_LIMIT = 3000
+
+
+class DataSourceApi(Resource):
+    integrate_icon_fields = {
+        'type': fields.String,
+        'url': fields.String,
+        'emoji': fields.String
+    }
+    integrate_page_fields = {
+        'page_name': fields.String,
+        'page_id': fields.String,
+        'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
+        'parent_id': fields.String,
+        'type': fields.String
+    }
+    integrate_workspace_fields = {
+        'workspace_name': fields.String,
+        'workspace_id': fields.String,
+        'workspace_icon': fields.String,
+        'pages': fields.List(fields.Nested(integrate_page_fields)),
+        'total': fields.Integer
+    }
+    integrate_fields = {
+        'id': fields.String,
+        'provider': fields.String,
+        'created_at': TimestampField,
+        'is_bound': fields.Boolean,
+        'disabled': fields.Boolean,
+        'link': fields.String,
+        'source_info': fields.Nested(integrate_workspace_fields)
+    }
+    integrate_list_fields = {
+        'data': fields.List(fields.Nested(integrate_fields)),
+    }
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(integrate_list_fields)
+    def get(self):
+        # get workspace data source integrates
+        data_source_integrates = db.session.query(DataSourceBinding).filter(
+            DataSourceBinding.tenant_id == current_user.current_tenant_id,
+            DataSourceBinding.disabled == False
+        ).all()
+
+        base_url = request.url_root.rstrip('/')
+        data_source_oauth_base_path = "/console/api/oauth/data-source"
+        providers = ["notion"]
+
+        integrate_data = []
+        for provider in providers:
+            # existing_integrate = next((ai for ai in data_source_integrates if ai.provider == provider), None)
+            existing_integrates = [item for item in data_source_integrates if item.provider == provider]
+            if existing_integrates:
+                for existing_integrate in existing_integrates:
+                    integrate_data.append({
+                        'id': existing_integrate.id,
+                        'provider': provider,
+                        'created_at': existing_integrate.created_at,
+                        'is_bound': True,
+                        'disabled': existing_integrate.disabled,
+                        'source_info': existing_integrate.source_info,
+                        'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
+                    })
+            else:
+                integrate_data.append({
+                    'id': None,
+                    'provider': provider,
+                    'created_at': None,
+                    'source_info': None,
+                    'is_bound': False,
+                    'disabled': None,
+                    'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
+                })
+        return {'data': integrate_data}, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def patch(self, binding_id, action):
+        binding_id = str(binding_id)
+        action = str(action)
+        data_source_binding = DataSourceBinding.query.filter_by(
+            id=binding_id
+        ).first()
+        if data_source_binding is None:
+            raise NotFound('Data source binding not found.')
+        # enable binding
+        if action == 'enable':
+            if data_source_binding.disabled:
+                data_source_binding.disabled = False
+                data_source_binding.updated_at = datetime.datetime.utcnow()
+                db.session.add(data_source_binding)
+                db.session.commit()
+            else:
+                raise ValueError('Data source is not disabled.')
+        # disable binding
+        if action == 'disable':
+            if not data_source_binding.disabled:
+                data_source_binding.disabled = True
+                data_source_binding.updated_at = datetime.datetime.utcnow()
+                db.session.add(data_source_binding)
+                db.session.commit()
+            else:
+                raise ValueError('Data source is disabled.')
+        return {'result': 'success'}, 200
+
+
+class DataSourceNotionListApi(Resource):
+    integrate_icon_fields = {
+        'type': fields.String,
+        'url': fields.String,
+        'emoji': fields.String
+    }
+    integrate_page_fields = {
+        'page_name': fields.String,
+        'page_id': fields.String,
+        'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
+        'is_bound': fields.Boolean,
+        'parent_id': fields.String,
+        'type': fields.String
+    }
+    integrate_workspace_fields = {
+        'workspace_name': fields.String,
+        'workspace_id': fields.String,
+        'workspace_icon': fields.String,
+        'pages': fields.List(fields.Nested(integrate_page_fields))
+    }
+    integrate_notion_info_list_fields = {
+        'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
+    }
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(integrate_notion_info_list_fields)
+    def get(self):
+        dataset_id = request.args.get('dataset_id', default=None, type=str)
+        exist_page_ids = []
+        # import notion in the exist dataset
+        if dataset_id:
+            dataset = DatasetService.get_dataset(dataset_id)
+            if not dataset:
+                raise NotFound('Dataset not found.')
+            if dataset.data_source_type != 'notion_import':
+                raise ValueError('Dataset is not notion type.')
+            documents = Document.query.filter_by(
+                dataset_id=dataset_id,
+                tenant_id=current_user.current_tenant_id,
+                data_source_type='notion_import',
+                enabled=True
+            ).all()
+            if documents:
+                for document in documents:
+                    data_source_info = json.loads(document.data_source_info)
+                    exist_page_ids.append(data_source_info['notion_page_id'])
+        # get all authorized pages
+        data_source_bindings = DataSourceBinding.query.filter_by(
+            tenant_id=current_user.current_tenant_id,
+            provider='notion',
+            disabled=False
+        ).all()
+        if not data_source_bindings:
+            return {
+                'notion_info': []
+            }, 200
+        pre_import_info_list = []
+        for data_source_binding in data_source_bindings:
+            source_info = data_source_binding.source_info
+            pages = source_info['pages']
+            # Filter out already bound pages
+            for page in pages:
+                if page['page_id'] in exist_page_ids:
+                    page['is_bound'] = True
+                else:
+                    page['is_bound'] = False
+            pre_import_info = {
+                'workspace_name': source_info['workspace_name'],
+                'workspace_icon': source_info['workspace_icon'],
+                'workspace_id': source_info['workspace_id'],
+                'pages': pages,
+            }
+            pre_import_info_list.append(pre_import_info)
+        return {
+            'notion_info': pre_import_info_list
+        }, 200
+
+
+class DataSourceNotionApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, workspace_id, page_id, page_type):
+        workspace_id = str(workspace_id)
+        page_id = str(page_id)
+        data_source_binding = DataSourceBinding.query.filter(
+            db.and_(
+                DataSourceBinding.tenant_id == current_user.current_tenant_id,
+                DataSourceBinding.provider == 'notion',
+                DataSourceBinding.disabled == False,
+                DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+            )
+        ).first()
+        if not data_source_binding:
+            raise NotFound('Data source binding not found.')
+        reader = NotionPageReader(integration_token=data_source_binding.access_token)
+        if page_type == 'page':
+            page_content = reader.read_page(page_id)
+        elif page_type == 'database':
+            page_content = reader.query_database_data(page_id)
+        else:
+            page_content = ""
+        return {
+            'content': page_content
+        }, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
+        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        args = parser.parse_args()
+        # validate args
+        DocumentService.estimate_args_validate(args)
+        indexing_runner = IndexingRunner()
+        response = indexing_runner.notion_indexing_estimate(args['notion_info_list'], args['process_rule'])
+        return response, 200
+
+
+class DataSourceNotionDatasetSyncApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id):
+        dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+
+        documents = DocumentService.get_document_by_dataset_id(dataset_id_str)
+        for document in documents:
+            document_indexing_sync_task.delay(dataset_id_str, document.id)
+        return 200
+
+
+class DataSourceNotionDocumentSyncApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id, document_id):
+        dataset_id_str = str(dataset_id)
+        document_id_str = str(document_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+
+        document = DocumentService.get_document(dataset_id_str, document_id_str)
+        if document is None:
+            raise NotFound("Document not found.")
+        document_indexing_sync_task.delay(dataset_id_str, document_id_str)
+        return 200
+
+
+api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>')
+api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages')
+api.add_resource(DataSourceNotionApi,
+                 '/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview',
+                 '/datasets/notion-indexing-estimate')
+api.add_resource(DataSourceNotionDatasetSyncApi, '/datasets/<uuid:dataset_id>/notion/sync')
+api.add_resource(DataSourceNotionDocumentSyncApi,
+                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync')
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 2defcec6f..b7898d966 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -12,8 +12,9 @@ from controllers.console.wraps import account_initialization_required
 from core.indexing_runner import IndexingRunner
 from libs.helper import TimestampField
 from extensions.ext_database import db
+from models.dataset import DocumentSegment, Document
 from models.model import UploadFile
-from services.dataset_service import DatasetService
+from services.dataset_service import DatasetService, DocumentService
 
 dataset_detail_fields = {
     'id': fields.String,
@@ -217,17 +218,31 @@ class DatasetIndexingEstimateApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        segment_rule = request.get_json()
-        file_detail = db.session.query(UploadFile).filter(
-            UploadFile.tenant_id == current_user.current_tenant_id,
-            UploadFile.id == segment_rule["file_id"]
-        ).first()
+        parser = reqparse.RequestParser()
+        parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        args = parser.parse_args()
+        # validate args
+        DocumentService.estimate_args_validate(args)
+        if args['info_list']['data_source_type'] == 'upload_file':
+            file_ids = args['info_list']['file_info_list']['file_ids']
+            file_details = db.session.query(UploadFile).filter(
+                UploadFile.tenant_id == current_user.current_tenant_id,
+                UploadFile.id.in_(file_ids)
+            ).all()
 
-        if file_detail is None:
-            raise NotFound("File not found.")
+            if not file_details:
+                raise NotFound("File not found.")
 
-        indexing_runner = IndexingRunner()
-        response = indexing_runner.indexing_estimate(file_detail, segment_rule['process_rule'])
+            indexing_runner = IndexingRunner()
+            response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'])
+        elif args['info_list']['data_source_type'] == 'notion_import':
+
+            indexing_runner = IndexingRunner()
+            response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
+                                                                args['process_rule'])
+        else:
+            raise ValueError('Data source type is not supported.')
 
         return response, 200
 
@@ -274,8 +289,54 @@ class DatasetRelatedAppListApi(Resource):
         }, 200
 
 
+class DatasetIndexingStatusApi(Resource):
+    document_status_fields = {
+        'id': fields.String,
+        'indexing_status': fields.String,
+        'processing_started_at': TimestampField,
+        'parsing_completed_at': TimestampField,
+        'cleaning_completed_at': TimestampField,
+        'splitting_completed_at': TimestampField,
+        'completed_at': TimestampField,
+        'paused_at': TimestampField,
+        'error': fields.String,
+        'stopped_at': TimestampField,
+        'completed_segments': fields.Integer,
+        'total_segments': fields.Integer,
+    }
+
+    document_status_fields_list = {
+        'data': fields.List(fields.Nested(document_status_fields))
+    }
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id):
+        dataset_id = str(dataset_id)
+        documents = db.session.query(Document).filter(
+            Document.dataset_id == dataset_id,
+            Document.tenant_id == current_user.current_tenant_id
+        ).all()
+        documents_status = []
+        for document in documents:
+            completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
+                                                              DocumentSegment.document_id == str(document.id),
+                                                              DocumentSegment.status != 're_segment').count()
+            total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
+                                                          DocumentSegment.status != 're_segment').count()
+            document.completed_segments = completed_segments
+            document.total_segments = total_segments
+            documents_status.append(marshal(document, self.document_status_fields))
+        data = {
+            'data': documents_status
+        }
+        return data
+
+
 api.add_resource(DatasetListApi, '/datasets')
 api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
 api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
-api.add_resource(DatasetIndexingEstimateApi, '/datasets/file-indexing-estimate')
+api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
 api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
+api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 6888f0ed3..3e65b1319 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -1,6 +1,7 @@
 # -*- coding:utf-8 -*-
 import random
 from datetime import datetime
+from typing import List
 
 from flask import request
 from flask_login import login_required, current_user
@@ -61,6 +62,29 @@ document_fields = {
     'hit_count': fields.Integer,
 }
 
+document_with_segments_fields = {
+    'id': fields.String,
+    'position': fields.Integer,
+    'data_source_type': fields.String,
+    'data_source_info': fields.Raw(attribute='data_source_info_dict'),
+    'dataset_process_rule_id': fields.String,
+    'name': fields.String,
+    'created_from': fields.String,
+    'created_by': fields.String,
+    'created_at': TimestampField,
+    'tokens': fields.Integer,
+    'indexing_status': fields.String,
+    'error': fields.String,
+    'enabled': fields.Boolean,
+    'disabled_at': TimestampField,
+    'disabled_by': fields.String,
+    'archived': fields.Boolean,
+    'display_status': fields.String,
+    'word_count': fields.Integer,
+    'hit_count': fields.Integer,
+    'completed_segments': fields.Integer,
+    'total_segments': fields.Integer
+}
+
 
 class DocumentResource(Resource):
     def get_document(self, dataset_id: str, document_id: str) -> Document:
@@ -83,6 +107,23 @@ class DocumentResource(Resource):
 
         return document
 
+    def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]:
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound('Dataset not found.')
+
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+
+        documents = DocumentService.get_batch_documents(dataset_id, batch)
+
+        if not documents:
+            raise NotFound('Documents not found.')
+
+        return documents
+
 
 class GetProcessRuleApi(Resource):
     @setup_required
@@ -132,9 +173,9 @@ class DatasetDocumentListApi(Resource):
         dataset_id = str(dataset_id)
         page = request.args.get('page', default=1, type=int)
         limit = request.args.get('limit', default=20, type=int)
-        search = request.args.get('search', default=None, type=str)
+        search = request.args.get('keyword', default=None, type=str)
         sort = request.args.get('sort', default='-created_at', type=str)
-
+        fetch = request.args.get('fetch', default=False, type=bool)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
@@ -173,9 +214,20 @@ class DatasetDocumentListApi(Resource):
         paginated_documents = query.paginate(
             page=page, per_page=limit, max_per_page=100, error_out=False)
         documents = paginated_documents.items
-
+        if fetch:
+            for document in documents:
+                completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
+                                                                  DocumentSegment.document_id == str(document.id),
+                                                                  DocumentSegment.status != 're_segment').count()
+                total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
+                                                              DocumentSegment.status != 're_segment').count()
+                document.completed_segments = completed_segments
+                document.total_segments = total_segments
+            data = marshal(documents, document_with_segments_fields)
+        else:
+            data = marshal(documents, document_fields)
         response = {
-            'data': marshal(documents, document_fields),
+            'data': data,
             'has_more': len(documents) == limit,
             'limit': limit,
             'total': paginated_documents.total,
@@ -184,10 +236,15 @@ class DatasetDocumentListApi(Resource):
 
         return response
 
+    documents_and_batch_fields = {
+        'documents': fields.List(fields.Nested(document_fields)),
+        'batch': fields.String
+    }
+
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(document_fields)
+    @marshal_with(documents_and_batch_fields)
     def post(self, dataset_id):
         dataset_id = str(dataset_id)
 
@@ -221,7 +278,7 @@ class DatasetDocumentListApi(Resource):
         DocumentService.document_create_args_validate(args)
 
         try:
-            document = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
+            documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
         except ProviderTokenNotInitError:
             raise ProviderNotInitializeError()
         except QuotaExceededError:
@@ -229,13 +286,17 @@ class DatasetDocumentListApi(Resource):
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
 
-        return document
+        return {
+            'documents': documents,
+            'batch': batch
+        }
 
 
 class DatasetInitApi(Resource):
     dataset_and_document_fields = {
         'dataset': fields.Nested(dataset_fields),
-        'document': fields.Nested(document_fields)
+        'documents': fields.List(fields.Nested(document_fields)),
+        'batch': fields.String
     }
 
     @setup_required
@@ -258,7 +319,7 @@ class DatasetInitApi(Resource):
         DocumentService.document_create_args_validate(args)
 
         try:
-            dataset, document = DocumentService.save_document_without_dataset_id(
+            dataset, documents, batch = DocumentService.save_document_without_dataset_id(
                 tenant_id=current_user.current_tenant_id,
                 document_data=args,
                 account=current_user
@@ -272,7 +333,8 @@ class DatasetInitApi(Resource):
 
         response = {
             'dataset': dataset,
-            'document': document
+            'documents': documents,
+            'batch': batch
         }
 
         return response
@@ -317,11 +379,122 @@ class DocumentIndexingEstimateApi(DocumentResource):
             raise NotFound('File not found.')
 
         indexing_runner = IndexingRunner()
-        response = indexing_runner.indexing_estimate(file, data_process_rule_dict)
+
+        response = indexing_runner.file_indexing_estimate([file], data_process_rule_dict)
 
         return response
 
 
+class DocumentBatchIndexingEstimateApi(DocumentResource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id, batch):
+        dataset_id = str(dataset_id)
+        batch = str(batch)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        documents = self.get_batch_documents(dataset_id, batch)
+        response = {
+            "tokens": 0,
+            "total_price": 0,
+            "currency": "USD",
+            "total_segments": 0,
+            "preview": []
+        }
+        if not documents:
+            return response
+        data_process_rule = documents[0].dataset_process_rule
+        data_process_rule_dict = data_process_rule.to_dict()
+        info_list = []
+        for document in documents:
+            if document.indexing_status in ['completed', 'error']:
+                raise DocumentAlreadyFinishedError()
+            data_source_info = document.data_source_info_dict
+            # format document files info
+            if data_source_info and 'upload_file_id' in data_source_info:
+                file_id = data_source_info['upload_file_id']
+                info_list.append(file_id)
+            # format document notion info
+            elif data_source_info and 'notion_workspace_id' in data_source_info and 'notion_page_id' in data_source_info:
+                pages = []
+                page = {
+                    'page_id': data_source_info['notion_page_id'],
+                    'type': data_source_info['type']
+                }
+                pages.append(page)
+                notion_info = {
+                    'workspace_id': data_source_info['notion_workspace_id'],
+                    'pages': pages
+                }
+                info_list.append(notion_info)
+
+        if dataset.data_source_type == 'upload_file':
+            file_details = db.session.query(UploadFile).filter(
+                UploadFile.tenant_id == current_user.current_tenant_id,
+                UploadFile.id.in_(info_list)
+            ).all()
+
+            if not file_details:
+                raise NotFound("File not found.")
+
+            indexing_runner = IndexingRunner()
+            response = indexing_runner.file_indexing_estimate(file_details, data_process_rule_dict)
+        elif dataset.data_source_type == 'notion_import':
+
+            indexing_runner = IndexingRunner()
+            response = indexing_runner.notion_indexing_estimate(info_list,
+                                                                data_process_rule_dict)
+        else:
+            raise ValueError('Data source type is not supported.')
+        return response
+
+
+class DocumentBatchIndexingStatusApi(DocumentResource):
+    document_status_fields = {
+        'id': fields.String,
+        'indexing_status': fields.String,
+        'processing_started_at': TimestampField,
+        'parsing_completed_at': TimestampField,
+        'cleaning_completed_at': TimestampField,
+        'splitting_completed_at': TimestampField,
+        'completed_at': TimestampField,
+        'paused_at': TimestampField,
+        'error': fields.String,
+        'stopped_at': TimestampField,
+        'completed_segments': fields.Integer,
+        'total_segments': fields.Integer,
+    }
+
+    document_status_fields_list = {
+        'data': fields.List(fields.Nested(document_status_fields))
+    }
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id, batch):
+        dataset_id = str(dataset_id)
+        batch = str(batch)
+        documents = self.get_batch_documents(dataset_id, batch)
+        documents_status = []
+        for document in documents:
+            completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
+                                                              DocumentSegment.document_id == str(document.id),
+                                                              DocumentSegment.status != 're_segment').count()
+            total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
+                                                          DocumentSegment.status != 're_segment').count()
+            document.completed_segments = completed_segments
+            document.total_segments = total_segments
+            documents_status.append(marshal(document, self.document_status_fields))
+        data = {
+            'data': documents_status
+        }
+        return data
+
+
 class DocumentIndexingStatusApi(DocumentResource):
     document_status_fields = {
         'id': fields.String,
@@ -408,7 +581,7 @@ class DocumentDetailApi(DocumentResource):
                 'disabled_by': document.disabled_by,
                 'archived': document.archived,
                 'segment_count': document.segment_count,
-                'average_segment_length': document.average_segment_length, 
+                'average_segment_length': document.average_segment_length,
                 'hit_count': document.hit_count,
                 'display_status': document.display_status
             }
@@ -428,7 +601,7 @@ class DocumentDetailApi(DocumentResource):
                 'created_at': document.created_at.timestamp(),
                 'tokens': document.tokens,
                 'indexing_status': document.indexing_status,
-                'completed_at': int(document.completed_at.timestamp())if document.completed_at else None,
+                'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None,
                 'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None,
                 'indexing_latency': document.indexing_latency,
                 'error': document.error,
@@ -579,6 +752,8 @@ class DocumentStatusApi(DocumentResource):
             return {'result': 'success'}, 200
 
         elif action == "disable":
+            if not document.completed_at or document.indexing_status != 'completed':
+                raise InvalidActionError('Document is not completed.')
             if not document.enabled:
                 raise InvalidActionError('Document already disabled.')
 
@@ -678,6 +853,10 @@ api.add_resource(DatasetInitApi, '/datasets/init')
 api.add_resource(DocumentIndexingEstimateApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate')
+api.add_resource(DocumentBatchIndexingEstimateApi,
+                 '/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate')
+api.add_resource(DocumentBatchIndexingStatusApi,
+                 '/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status')
 api.add_resource(DocumentIndexingStatusApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status')
 api.add_resource(DocumentDetailApi,
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index 47a90756d..3036882d7 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -69,12 +69,16 @@ class DocumentListApi(DatasetApiResource):
         document_data = {
             'data_source': {
                 'type': 'upload_file',
-                'info': upload_file.id
+                'info': [
+                    {
+                        'upload_file_id': upload_file.id
+                    }
+                ]
             }
         }
 
         try:
-            document = DocumentService.save_document_with_dataset_id(
+            documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
                 document_data=document_data,
                 account=dataset.created_by_account,
@@ -83,7 +87,7 @@ class DocumentListApi(DatasetApiResource):
             )
         except ProviderTokenNotInitError:
             raise ProviderNotInitializeError()
-
+        document = documents[0]
         if doc_type and doc_metadata:
             metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
 
diff --git a/api/core/data_source/notion.py b/api/core/data_source/notion.py
new file mode 100644
index 000000000..4ed724f67
--- /dev/null
+++ b/api/core/data_source/notion.py
@@ -0,0 +1,367 @@
+"""Notion reader."""
+import json
+import logging
+import os
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+
+import requests  # type: ignore
+
+from llama_index.readers.base import BaseReader
+from llama_index.readers.schema.base import Document
+
+INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN"
+BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
+DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
+SEARCH_URL = "https://api.notion.com/v1/search"
+RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
+RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
+HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
+logger = logging.getLogger(__name__)
+
+
+# TODO: Notion DB reader coming soon!
+class NotionPageReader(BaseReader):
+    """Notion Page reader.
+
+    Reads a set of Notion pages.
+
+    Args:
+        integration_token (str): Notion integration token.
+
+    """
+
+    def __init__(self, integration_token: Optional[str] = None) -> None:
+        """Initialize with parameters."""
+        if integration_token is None:
+            integration_token = os.getenv(INTEGRATION_TOKEN_NAME)
+            if integration_token is None:
+                raise ValueError(
+                    "Must specify `integration_token` or set environment "
+                    "variable `NOTION_INTEGRATION_TOKEN`."
+                )
+        self.token = integration_token
+        self.headers = {
+            "Authorization": "Bearer " + self.token,
+            "Content-Type": "application/json",
+            "Notion-Version": "2022-06-28",
+        }
+
+    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
+        """Read a block."""
+        done = False
+        result_lines_arr = []
+        cur_block_id = block_id
+        while not done:
+            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
+            query_dict: Dict[str, Any] = {}
+
+            res = requests.request(
+                "GET", block_url, headers=self.headers, json=query_dict
+            )
+            data = res.json()
+            if 'results' not in data or data["results"] is None:
+                done = True
+                break
+            heading = ''
+            for result in data["results"]:
+                result_type = result["type"]
+                result_obj = result[result_type]
+                cur_result_text_arr = []
+                if result_type == 'table':
+                    result_block_id = result["id"]
+                    text = self._read_table_rows(result_block_id)
+                    result_lines_arr.append(text)
+                else:
+                    if "rich_text" in result_obj:
+                        for rich_text in result_obj["rich_text"]:
+                            # skip if doesn't have text object
+                            if "text" in rich_text:
+                                text = rich_text["text"]["content"]
+                                prefix = "\t" * num_tabs
+                                cur_result_text_arr.append(prefix + text)
+                                if result_type in HEADING_TYPE:
+                                    heading = text
+                    result_block_id = result["id"]
+                    has_children = result["has_children"]
+                    if has_children:
+                        children_text = self._read_block(
+                            result_block_id, num_tabs=num_tabs + 1
+                        )
+                        cur_result_text_arr.append(children_text)
+
+                    cur_result_text = "\n".join(cur_result_text_arr)
+                    if result_type in HEADING_TYPE:
+                        result_lines_arr.append(cur_result_text)
+                    else:
+                        result_lines_arr.append(f'{heading}\n{cur_result_text}')
+
+            if data["next_cursor"] is None:
+                done = True
+                break
+            else:
+                cur_block_id = data["next_cursor"]
+
+        result_lines = "\n".join(result_lines_arr)
+        return result_lines
+
+    def _read_table_rows(self, block_id: str) -> str:
+        """Read table rows."""
+        done = False
+        result_lines_arr = []
+        cur_block_id = block_id
+        while not done:
+            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
+            query_dict: Dict[str, Any] = {}
+
+            res = requests.request(
+                "GET", block_url, headers=self.headers, json=query_dict
+            )
+            data = res.json()
+            # get table headers text
+            table_header_cell_texts = []
+            table_header_cells = data["results"][0]['table_row']['cells']
+            for table_header_cell in table_header_cells:
+                if table_header_cell:
+                    for table_header_cell_text in table_header_cell:
+                        text = table_header_cell_text["text"]["content"]
+                        table_header_cell_texts.append(text)
+            # get table columns text and format
+            results = data["results"]
+            for i in range(len(results) - 1):
+                column_texts = []
+                table_column_cells = data["results"][i + 1]['table_row']['cells']
+                for j in range(len(table_column_cells)):
+                    if table_column_cells[j]:
+                        for table_column_cell_text in table_column_cells[j]:
+                            column_text = table_column_cell_text["text"]["content"]
+                            column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
+
+                cur_result_text = "\n".join(column_texts)
+                result_lines_arr.append(cur_result_text)
+
+            if data["next_cursor"] is None:
+                done = True
+                break
+            else:
+                cur_block_id = data["next_cursor"]
+
+        result_lines = "\n".join(result_lines_arr)
+        return result_lines
+
+    def _read_parent_blocks(self, block_id: str, num_tabs: int = 0) -> List[str]:
+        """Read a block."""
+        done = False
+        result_lines_arr = []
+        cur_block_id = block_id
+        while not done:
+            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
+            query_dict: Dict[str, Any] = {}
+
+            res = requests.request(
+                "GET", block_url, headers=self.headers, json=query_dict
+            )
+            data = res.json()
+            # current block's heading
+            heading = ''
+            for result in data["results"]:
+                result_type = result["type"]
+                result_obj = result[result_type]
+                cur_result_text_arr = []
+                if result_type == 'table':
+                    result_block_id = result["id"]
+                    text = self._read_table_rows(result_block_id)
+                    text += "\n\n"
+                    result_lines_arr.append(text)
+                else:
+                    if "rich_text" in result_obj:
+                        for rich_text in result_obj["rich_text"]:
+                            # skip if doesn't have text object
+                            if "text" in rich_text:
+                                text = rich_text["text"]["content"]
+                                cur_result_text_arr.append(text)
+                                if result_type in HEADING_TYPE:
+                                    heading = text
+
+                    result_block_id = result["id"]
+                    has_children = result["has_children"]
+                    if has_children:
+                        children_text = self._read_block(
+                            result_block_id, num_tabs=num_tabs + 1
+                        )
+                        cur_result_text_arr.append(children_text)
+
+                    cur_result_text = "\n".join(cur_result_text_arr)
+                    cur_result_text += "\n\n"
+                    if result_type in HEADING_TYPE:
+                        result_lines_arr.append(cur_result_text)
+                    else:
+                        result_lines_arr.append(f'{heading}\n{cur_result_text}')
+
+            if data["next_cursor"] is None:
+                done = True
+                break
+            else:
+                cur_block_id = data["next_cursor"]
+        return result_lines_arr
+
+    def read_page(self, page_id: str) -> str:
+        """Read a page."""
+        return self._read_block(page_id)
+
+    def read_page_as_documents(self, page_id: str) -> List[str]:
+        """Read a page as documents."""
+        return self._read_parent_blocks(page_id)
+
+    def query_database_data(
+            self, database_id: str, query_dict: Dict[str, Any] = {}
+    ) -> str:
+        """Get all the pages from a Notion database."""
+        res = requests.post(
+            DATABASE_URL_TMPL.format(database_id=database_id),
+            headers=self.headers,
+            json=query_dict,
+        )
+        data = res.json()
+        database_content_list = []
+        if 'results' not in data or data["results"] is None:
+            return ""
+        for result in data["results"]:
+            properties = result['properties']
+            data = {}
+            for property_name, property_value in properties.items():
+                type = property_value['type']
+                if type == 'multi_select':
+                    value = []
+                    multi_select_list = property_value[type]
+                    for multi_select in multi_select_list:
+                        value.append(multi_select['name'])
+                elif type == 'rich_text' or type == 'title':
+                    if len(property_value[type]) > 0:
+                        value = property_value[type][0]['plain_text']
+                    else:
+                        value = ''
+                elif type == 'select' or type == 'status':
+                    if property_value[type]:
+                        value = property_value[type]['name']
+                    else:
+                        value = ''
+                else:
+                    value = property_value[type]
+                data[property_name] = value
+            database_content_list.append(json.dumps(data))
+
+        return "\n\n".join(database_content_list)
+
+    def query_database(
+            self, database_id: str, query_dict: Dict[str, Any] = {}
+    ) -> List[str]:
+        """Get all the pages from a Notion database."""
+        res = requests.post(
+            DATABASE_URL_TMPL.format(database_id=database_id),
+            headers=self.headers,
+            json=query_dict,
+        )
+        data = res.json()
+        page_ids = []
+        for result in data["results"]:
+            page_id = result["id"]
+            page_ids.append(page_id)
+
+        return page_ids
+
+    def search(self, query: str) -> List[str]:
+        """Search Notion page given a text query."""
+        done = False
+        next_cursor: Optional[str] = None
+        page_ids = []
+        while not done:
+            query_dict = {
+                "query": query,
+            }
+            if next_cursor is not None:
+                query_dict["start_cursor"] = next_cursor
+            res = requests.post(SEARCH_URL, headers=self.headers, json=query_dict)
+            data = res.json()
+            for result in data["results"]:
+                page_id = result["id"]
+                page_ids.append(page_id)
+
+            if data["next_cursor"] is None:
+                done = True
+                break
+            else:
+                next_cursor = data["next_cursor"]
+        return page_ids
+
+    def load_data(
+            self, page_ids: List[str] = [], database_id: Optional[str] = None
+    ) -> List[Document]:
+        """Load data from the input directory.
+
+        Args:
+            page_ids (List[str]): List of page ids to load.
+
+        Returns:
+            List[Document]: List of documents.
+
+        """
+        if not page_ids and not database_id:
+            raise ValueError("Must specify either `page_ids` or `database_id`.")
+        docs = []
+        if database_id is not None:
+            # get all the pages in the database
+            page_ids = self.query_database(database_id)
+            for page_id in page_ids:
+                page_text = self.read_page(page_id)
+                docs.append(Document(page_text))
+        else:
+            for page_id in page_ids:
+                page_text = self.read_page(page_id)
+                docs.append(Document(page_text))
+
+        return docs
+
+    def load_data_as_documents(
+            self, page_ids: List[str] = [], database_id: Optional[str] = None
+    ) -> List[Document]:
+        if not page_ids and not database_id:
+            raise ValueError("Must specify either `page_ids` or `database_id`.")
+        docs = []
+        if database_id is not None:
+            # get all the pages in the database
+            page_text = self.query_database_data(database_id)
+            docs.append(Document(page_text))
+        else:
+            for page_id in page_ids:
+                page_text_list = self.read_page_as_documents(page_id)
+                for page_text in page_text_list:
+                    docs.append(Document(page_text))
+
+        return docs
+
+    def get_page_last_edited_time(self, page_id: str) -> str:
+        retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=page_id)
+        query_dict: Dict[str, Any] = {}
+
+        res = requests.request(
+            "GET", retrieve_page_url, headers=self.headers, json=query_dict
+        )
+        data = res.json()
+        return data["last_edited_time"]
+
+    def get_database_last_edited_time(self, database_id: str) -> str:
+        retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=database_id)
+        query_dict: Dict[str, Any] = {}
+
+        res = requests.request(
+            "GET", retrieve_page_url, headers=self.headers, json=query_dict
+        )
+        data = res.json()
+        return data["last_edited_time"]
+
+
+if __name__ == "__main__":
+    reader = NotionPageReader()
+    logger.info(reader.search("What I"))
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 17ae53a90..319a6bad1 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -5,6 +5,8 @@ import tempfile
 import time
 from pathlib import Path
 from typing import Optional, List
+
+from flask_login import current_user
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from llama_index import SimpleDirectoryReader
@@ -13,6 +15,8 @@ from llama_index.data_structs.node_v2 import DocumentRelationship
 from llama_index.node_parser import SimpleNodeParser, NodeParser
 from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
 from llama_index.readers.file.markdown_parser import MarkdownParser
+
+from core.data_source.notion import NotionPageReader
 from core.index.readers.xlsx_parser import XLSXParser
 from core.docstore.dataset_docstore import DatesetDocumentStore
 from core.index.keyword_table_index import KeywordTableIndex
@@ -27,6 +31,7 @@ from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
 from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule
 from models.model import UploadFile
+from models.source import DataSourceBinding
 
 
 class IndexingRunner:
@@ -35,42 +40,43 @@ class IndexingRunner:
         self.storage = storage
         self.embedding_model_name = embedding_model_name
 
-    def run(self, document: Document):
+    def run(self, documents: List[Document]):
         """Run the indexing process."""
-        # get dataset
-        dataset = Dataset.query.filter_by(
-            id=document.dataset_id
-        ).first()
+        for document in documents:
+            # get dataset
+            dataset = Dataset.query.filter_by(
+                id=document.dataset_id
+            ).first()
 
-        if not dataset:
-            raise ValueError("no dataset found")
+            if not dataset:
+                raise ValueError("no dataset found")
 
-        # load file
-        text_docs = self._load_data(document)
+            # load file
+            text_docs = self._load_data(document)
 
-        # get the process rule
-        processing_rule = db.session.query(DatasetProcessRule). \
-            filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
-            first()
+            # get the process rule
+            processing_rule = db.session.query(DatasetProcessRule). \
+                filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
+                first()
 
-        # get node parser for splitting
-        node_parser = self._get_node_parser(processing_rule)
+            # get node parser for splitting
+            node_parser = self._get_node_parser(processing_rule)
 
-        # split to nodes
-        nodes = self._step_split(
-            text_docs=text_docs,
-            node_parser=node_parser,
-            dataset=dataset,
-            document=document,
-            processing_rule=processing_rule
-        )
+            # split to nodes
+            nodes = self._step_split(
+                text_docs=text_docs,
+                node_parser=node_parser,
+                dataset=dataset,
+                document=document,
+                processing_rule=processing_rule
+            )
 
-        # build index
-        self._build_index(
-            dataset=dataset,
-            document=document,
-            nodes=nodes
-        )
+            # build index
+            self._build_index(
+                dataset=dataset,
+                document=document,
+                nodes=nodes
+            )
 
     def run_in_splitting_status(self, document: Document):
         """Run the indexing process when the index_status is splitting."""
@@ -164,38 +170,98 @@ class IndexingRunner:
             nodes=nodes
         )
 
-    def indexing_estimate(self, file_detail: UploadFile, tmp_processing_rule: dict) -> dict:
+    def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
         """
         Estimate the indexing for the document.
         """
-        # load data from file
-        text_docs = self._load_data_from_file(file_detail)
-
-        processing_rule = DatasetProcessRule(
-            mode=tmp_processing_rule["mode"],
-            rules=json.dumps(tmp_processing_rule["rules"])
-        )
-
-        # get node parser for splitting
-        node_parser = self._get_node_parser(processing_rule)
-
-        # split to nodes
-        nodes = self._split_to_nodes(
-            text_docs=text_docs,
-            node_parser=node_parser,
-            processing_rule=processing_rule
-        )
-
         tokens = 0
         preview_texts = []
-        for node in nodes:
-            if len(preview_texts) < 5:
-                preview_texts.append(node.get_text())
+        total_segments = 0
+        for file_detail in file_details:
+            # load data from file
+            text_docs = self._load_data_from_file(file_detail)
 
-            tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
+            processing_rule = DatasetProcessRule(
+                mode=tmp_processing_rule["mode"],
+                rules=json.dumps(tmp_processing_rule["rules"])
+            )
+
+            # get node parser for splitting
+            node_parser = self._get_node_parser(processing_rule)
+
+            # split to nodes
+            nodes = self._split_to_nodes(
+                text_docs=text_docs,
+                node_parser=node_parser,
+                processing_rule=processing_rule
+            )
+            total_segments += len(nodes)
+            for node in nodes:
+                if len(preview_texts) < 5:
+                    preview_texts.append(node.get_text())
+
+                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
 
         return {
-            "total_segments": len(nodes),
+            "total_segments": total_segments,
+            "tokens": tokens,
+            "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
+            "currency": TokenCalculator.get_currency(self.embedding_model_name),
+            "preview": preview_texts
+        }
+
+    def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict) -> dict:
+        """
+        Estimate the indexing for the document.
+        """
+        # load data from notion
+        tokens = 0
+        preview_texts = []
+        total_segments = 0
+        for notion_info in notion_info_list:
+            workspace_id = notion_info['workspace_id']
+            data_source_binding = DataSourceBinding.query.filter(
+                db.and_(
+                    DataSourceBinding.tenant_id == current_user.current_tenant_id,
+                    DataSourceBinding.provider == 'notion',
+                    DataSourceBinding.disabled == False,
+                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                )
+            ).first()
+            if not data_source_binding:
+                raise ValueError('Data source binding not found.')
+            reader = NotionPageReader(integration_token=data_source_binding.access_token)
+            for page in notion_info['pages']:
+                if page['type'] == 'page':
+                    page_ids = [page['page_id']]
+                    documents = reader.load_data_as_documents(page_ids=page_ids)
+                elif page['type'] == 'database':
+                    documents = reader.load_data_as_documents(database_id=page['page_id'])
+                else:
+                    documents = []
+                processing_rule = DatasetProcessRule(
+                    mode=tmp_processing_rule["mode"],
+                    rules=json.dumps(tmp_processing_rule["rules"])
+                )
+
+                # get node parser for splitting
+                node_parser = self._get_node_parser(processing_rule)
+
+                # split to nodes
+                nodes = self._split_to_nodes(
+                    text_docs=documents,
+                    node_parser=node_parser,
+                    processing_rule=processing_rule
+                )
+                total_segments += len(nodes)
+                for node in nodes:
+                    if len(preview_texts) < 5:
+                        preview_texts.append(node.get_text())
+
+                    tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
+
+        return {
+            "total_segments": total_segments,
             "tokens": tokens,
             "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
             "currency": TokenCalculator.get_currency(self.embedding_model_name),
             "preview": preview_texts
         }
 
@@ -204,25 +270,50 @@ class IndexingRunner:
     def _load_data(self, document: Document) -> List[Document]:
         # load file
-        if document.data_source_type != "upload_file":
+        if document.data_source_type not in ["upload_file", "notion_import"]:
             return []
 
         data_source_info = document.data_source_info_dict
-        if not data_source_info or 'upload_file_id' not in data_source_info:
-            raise ValueError("no upload file found")
+        text_docs = []
+        if document.data_source_type == 'upload_file':
+            if not data_source_info or 'upload_file_id' not in data_source_info:
+                raise ValueError("no upload file found")
 
-        file_detail = db.session.query(UploadFile). \
-            filter(UploadFile.id == data_source_info['upload_file_id']). \
-            one_or_none()
-
-        text_docs = self._load_data_from_file(file_detail)
+            file_detail = db.session.query(UploadFile). \
+                filter(UploadFile.id == data_source_info['upload_file_id']). \
\ + one_or_none() + text_docs = self._load_data_from_file(file_detail) + elif document.data_source_type == 'notion_import': + if not data_source_info or 'notion_page_id' not in data_source_info \ + or 'notion_workspace_id' not in data_source_info: + raise ValueError("no notion page found") + workspace_id = data_source_info['notion_workspace_id'] + page_id = data_source_info['notion_page_id'] + page_type = data_source_info['type'] + data_source_binding = DataSourceBinding.query.filter( + db.and_( + DataSourceBinding.tenant_id == document.tenant_id, + DataSourceBinding.provider == 'notion', + DataSourceBinding.disabled == False, + DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' + ) + ).first() + if not data_source_binding: + raise ValueError('Data source binding not found.') + if page_type == 'page': + # add page last_edited_time to data_source_info + self._get_notion_page_last_edited_time(page_id, data_source_binding.access_token, document) + text_docs = self._load_page_data_from_notion(page_id, data_source_binding.access_token) + elif page_type == 'database': + # add page last_edited_time to data_source_info + self._get_notion_database_last_edited_time(page_id, data_source_binding.access_token, document) + text_docs = self._load_database_data_from_notion(page_id, data_source_binding.access_token) # update document status to splitting self._update_document_index_status( document_id=document.id, after_indexing_status="splitting", extra_update_params={ - Document.file_id: file_detail.id, Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]), Document.parsing_completed_at: datetime.datetime.utcnow() } @@ -259,6 +350,41 @@ class IndexingRunner: return text_docs + def _load_page_data_from_notion(self, page_id: str, access_token: str) -> List[Document]: + page_ids = [page_id] + reader = NotionPageReader(integration_token=access_token) + text_docs = reader.load_data_as_documents(page_ids=page_ids) + return text_docs + + def _load_database_data_from_notion(self, database_id: str, access_token: str) -> List[Document]: + reader = NotionPageReader(integration_token=access_token) + text_docs = reader.load_data_as_documents(database_id=database_id) + return text_docs + + def _get_notion_page_last_edited_time(self, page_id: str, access_token: str, document: Document): + reader = NotionPageReader(integration_token=access_token) + last_edited_time = reader.get_page_last_edited_time(page_id) + data_source_info = document.data_source_info_dict + data_source_info['last_edited_time'] = last_edited_time + update_params = { + Document.data_source_info: json.dumps(data_source_info) + } + + Document.query.filter_by(id=document.id).update(update_params) + db.session.commit() + + def _get_notion_database_last_edited_time(self, page_id: str, access_token: str, document: Document): + reader = NotionPageReader(integration_token=access_token) + last_edited_time = reader.get_database_last_edited_time(page_id) + data_source_info = document.data_source_info_dict + data_source_info['last_edited_time'] = last_edited_time + update_params = { + Document.data_source_info: json.dumps(data_source_info) + } + + Document.query.filter_by(id=document.id).update(update_params) + db.session.commit() + def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser: """ Get the NodeParser object according to the processing rule. 
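A recurring subtlety in the hunks above: `DataSourceBinding.source_info` is a JSONB column, so subscripting it yields a JSON value, and the workspace id has to be compared in its JSON-encoded form, which is why it is wrapped in literal double quotes in `f'"{workspace_id}"'`. Below is a minimal sketch of that lookup pulled out as a standalone helper; it assumes only the `DataSourceBinding` model added by this patch, and the helper itself is not part of the change:

```python
from typing import Optional

from extensions.ext_database import db
from models.source import DataSourceBinding


def find_notion_binding(tenant_id: str, workspace_id: str) -> Optional[DataSourceBinding]:
    """Return the active Notion binding for a tenant's workspace, or None."""
    return DataSourceBinding.query.filter(
        db.and_(
            DataSourceBinding.tenant_id == tenant_id,
            DataSourceBinding.provider == 'notion',
            DataSourceBinding.disabled == False,
            # source_info is JSONB: subscripting yields a JSON value, so the
            # comparison target must be the JSON-encoded string, quotes included.
            DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
        )
    ).first()
```

Hoisting something like this would also deduplicate the four near-identical copies of this filter that appear in `indexing_runner.py`, `dataset_service.py`, and `document_indexing_sync_task.py`.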
@@ -308,7 +434,7 @@ class IndexingRunner: embedding_model_name=self.embedding_model_name, document_id=document.id ) - + # add document segments doc_store.add_documents(nodes) # update document status to indexing diff --git a/api/libs/oauth.py b/api/libs/oauth.py index ce41f0c22..c89ac6d65 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,7 +1,12 @@ +import json import urllib.parse from dataclasses import dataclass import requests +from flask_login import current_user + +from extensions.ext_database import db +from models.source import DataSourceBinding @dataclass @@ -134,3 +139,5 @@ class GoogleOAuth(OAuth): name=None, email=raw_info['email'] ) + + diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py new file mode 100644 index 000000000..5ae53c59f --- /dev/null +++ b/api/libs/oauth_data_source.py @@ -0,0 +1,256 @@ +import json +import urllib.parse + +import requests +from flask_login import current_user + +from extensions.ext_database import db +from models.source import DataSourceBinding + + +class OAuthDataSource: + def __init__(self, client_id: str, client_secret: str, redirect_uri: str): + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = redirect_uri + + def get_authorization_url(self): + raise NotImplementedError() + + def get_access_token(self, code: str): + raise NotImplementedError() + + +class NotionOAuth(OAuthDataSource): + _AUTH_URL = 'https://api.notion.com/v1/oauth/authorize' + _TOKEN_URL = 'https://api.notion.com/v1/oauth/token' + _NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search" + _NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks" + + def get_authorization_url(self): + params = { + 'client_id': self.client_id, + 'response_type': 'code', + 'redirect_uri': self.redirect_uri, + 'owner': 'user' + } + return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" + + def get_access_token(self, code: str): + data = { + 'code': code, + 'grant_type': 'authorization_code', + 'redirect_uri': self.redirect_uri + } + headers = {'Accept': 'application/json'} + auth = (self.client_id, self.client_secret) + response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers) + + response_json = response.json() + access_token = response_json.get('access_token') + if not access_token: + raise ValueError(f"Error in Notion OAuth: {response_json}") + workspace_name = response_json.get('workspace_name') + workspace_icon = response_json.get('workspace_icon') + workspace_id = response_json.get('workspace_id') + # get all authorized pages + pages = self.get_authorized_pages(access_token) + source_info = { + 'workspace_name': workspace_name, + 'workspace_icon': workspace_icon, + 'workspace_id': workspace_id, + 'pages': pages, + 'total': len(pages) + } + # save data source binding + data_source_binding = DataSourceBinding.query.filter( + db.and_( + DataSourceBinding.tenant_id == current_user.current_tenant_id, + DataSourceBinding.provider == 'notion', + DataSourceBinding.access_token == access_token + ) + ).first() + if data_source_binding: + data_source_binding.source_info = source_info + data_source_binding.disabled = False + db.session.commit() + else: + new_data_source_binding = DataSourceBinding( + tenant_id=current_user.current_tenant_id, + access_token=access_token, + source_info=source_info, + provider='notion' + ) + db.session.add(new_data_source_binding) + db.session.commit() + + def sync_data_source(self, binding_id: str): + # save data source binding + data_source_binding = 
DataSourceBinding.query.filter( + db.and_( + DataSourceBinding.tenant_id == current_user.current_tenant_id, + DataSourceBinding.provider == 'notion', + DataSourceBinding.id == binding_id, + DataSourceBinding.disabled == False + ) + ).first() + if data_source_binding: + # get all authorized pages + pages = self.get_authorized_pages(data_source_binding.access_token) + source_info = data_source_binding.source_info + new_source_info = { + 'workspace_name': source_info['workspace_name'], + 'workspace_icon': source_info['workspace_icon'], + 'workspace_id': source_info['workspace_id'], + 'pages': pages, + 'total': len(pages) + } + data_source_binding.source_info = new_source_info + data_source_binding.disabled = False + db.session.commit() + else: + raise ValueError('Data source binding not found') + + def get_authorized_pages(self, access_token: str): + pages = [] + page_results = self.notion_page_search(access_token) + database_results = self.notion_database_search(access_token) + # get page detail + for page_result in page_results: + page_id = page_result['id'] + if 'Name' in page_result['properties']: + if len(page_result['properties']['Name']['title']) > 0: + page_name = page_result['properties']['Name']['title'][0]['plain_text'] + else: + page_name = 'Untitled' + elif 'title' in page_result['properties']: + if len(page_result['properties']['title']['title']) > 0: + page_name = page_result['properties']['title']['title'][0]['plain_text'] + else: + page_name = 'Untitled' + elif 'Title' in page_result['properties']: + if len(page_result['properties']['Title']['title']) > 0: + page_name = page_result['properties']['Title']['title'][0]['plain_text'] + else: + page_name = 'Untitled' + else: + page_name = 'Untitled' + page_icon = page_result['icon'] + if page_icon: + icon_type = page_icon['type'] + if icon_type == 'external' or icon_type == 'file': + url = page_icon[icon_type]['url'] + icon = { + 'type': 'url', + 'url': url if url.startswith('http') else f'https://www.notion.so{url}' + } + else: + icon = { + 'type': 'emoji', + 'emoji': page_icon[icon_type] + } + else: + icon = None + parent = page_result['parent'] + parent_type = parent['type'] + if parent_type == 'block_id': + parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type]) + elif parent_type == 'workspace': + parent_id = 'root' + else: + parent_id = parent[parent_type] + page = { + 'page_id': page_id, + 'page_name': page_name, + 'page_icon': icon, + 'parent_id': parent_id, + 'type': 'page' + } + pages.append(page) + # get database detail + for database_result in database_results: + page_id = database_result['id'] + if len(database_result['title']) > 0: + page_name = database_result['title'][0]['plain_text'] + else: + page_name = 'Untitled' + page_icon = database_result['icon'] + if page_icon: + icon_type = page_icon['type'] + if icon_type == 'external' or icon_type == 'file': + url = page_icon[icon_type]['url'] + icon = { + 'type': 'url', + 'url': url if url.startswith('http') else f'https://www.notion.so{url}' + } + else: + icon = { + 'type': icon_type, + icon_type: page_icon[icon_type] + } + else: + icon = None + parent = database_result['parent'] + parent_type = parent['type'] + if parent_type == 'block_id': + parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type]) + elif parent_type == 'workspace': + parent_id = 'root' + else: + parent_id = parent[parent_type] + page = { + 'page_id': page_id, + 'page_name': page_name, + 'page_icon': icon, + 'parent_id': parent_id, + 'type': 'database' + 
}
+            pages.append(page)
+        return pages
+
+    def notion_page_search(self, access_token: str):
+        data = {
+            'filter': {
+                "value": "page",
+                "property": "object"
+            }
+        }
+        headers = {
+            'Content-Type': 'application/json',
+            'Authorization': f"Bearer {access_token}",
+            'Notion-Version': '2022-06-28',
+        }
+        response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
+        response_json = response.json()
+        results = response_json['results']
+        return results
+
+    def notion_block_parent_page_id(self, access_token: str, block_id: str):
+        headers = {
+            'Authorization': f"Bearer {access_token}",
+            'Notion-Version': '2022-06-28',
+        }
+        response = requests.get(url=f'{self._NOTION_BLOCK_SEARCH}/{block_id}', headers=headers)
+        response_json = response.json()
+        parent = response_json['parent']
+        parent_type = parent['type']
+        if parent_type == 'block_id':
+            return self.notion_block_parent_page_id(access_token, parent[parent_type])
+        return parent[parent_type]
+
+    def notion_database_search(self, access_token: str):
+        data = {
+            'filter': {
+                "value": "database",
+                "property": "object"
+            }
+        }
+        headers = {
+            'Content-Type': 'application/json',
+            'Authorization': f"Bearer {access_token}",
+            'Notion-Version': '2022-06-28',
+        }
+        response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
+        response_json = response.json()
+        results = response_json['results']
+        return results
diff --git a/api/migrations/versions/e32f6ccb87c6_e08af0a69ccefbb59fa80c778efee300bb780980.py b/api/migrations/versions/e32f6ccb87c6_e08af0a69ccefbb59fa80c778efee300bb780980.py
new file mode 100644
index 000000000..444f224fc
--- /dev/null
+++ b/api/migrations/versions/e32f6ccb87c6_e08af0a69ccefbb59fa80c778efee300bb780980.py
@@ -0,0 +1,46 @@
+"""e08af0a69ccefbb59fa80c778efee300bb780980
+
+Revision ID: e32f6ccb87c6
+Revises: 614f77cecc48
+Create Date: 2023-06-06 19:58:33.103819
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = 'e32f6ccb87c6'
+down_revision = '614f77cecc48'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('data_source_bindings',
+    sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', postgresql.UUID(), nullable=False),
+    sa.Column('access_token', sa.String(length=255), nullable=False),
+    sa.Column('provider', sa.String(length=255), nullable=False),
+    sa.Column('source_info', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.Column('disabled', sa.Boolean(), server_default=sa.text('false'), nullable=True),
+    sa.PrimaryKeyConstraint('id', name='source_binding_pkey')
+    )
+    with op.batch_alter_table('data_source_bindings', schema=None) as batch_op:
+        batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False)
+        batch_op.create_index('source_info_idx', ['source_info'], unique=False, postgresql_using='gin')
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust!
### + with op.batch_alter_table('data_source_bindings', schema=None) as batch_op: + batch_op.drop_index('source_info_idx', postgresql_using='gin') + batch_op.drop_index('source_binding_tenant_id_idx') + + op.drop_table('data_source_bindings') + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 29588c1f3..bbc5340bc 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -190,7 +190,7 @@ class Document(db.Model): doc_type = db.Column(db.String(40), nullable=True) doc_metadata = db.Column(db.JSON, nullable=True) - DATA_SOURCES = ['upload_file'] + DATA_SOURCES = ['upload_file', 'notion_import'] @property def display_status(self): @@ -242,6 +242,8 @@ class Document(db.Model): 'created_at': file_detail.created_at.timestamp() } } + elif self.data_source_type == 'notion_import': + return json.loads(self.data_source_info) return {} @property diff --git a/api/models/source.py b/api/models/source.py new file mode 100644 index 000000000..c7c04075b --- /dev/null +++ b/api/models/source.py @@ -0,0 +1,21 @@ +from sqlalchemy.dialects.postgresql import UUID + +from extensions.ext_database import db +from sqlalchemy.dialects.postgresql import JSONB + +class DataSourceBinding(db.Model): + __tablename__ = 'data_source_bindings' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='source_binding_pkey'), + db.Index('source_binding_tenant_id_idx', 'tenant_id'), + db.Index('source_info_idx', "source_info", postgresql_using='gin') + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + access_token = db.Column(db.String(255), nullable=False) + provider = db.Column(db.String(255), nullable=False) + source_info = db.Column(JSONB, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 9a03a6338..2619cdcbe 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -3,7 +3,7 @@ import logging import datetime import time import random -from typing import Optional +from typing import Optional, List from extensions.ext_redis import redis_client from flask_login import current_user @@ -14,10 +14,12 @@ from extensions.ext_database import db from models.account import Account from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment from models.model import UploadFile +from models.source import DataSourceBinding from services.errors.account import NoPermissionError from services.errors.dataset import DatasetNameDuplicateError from services.errors.document import DocumentIndexingError from services.errors.file import FileNotExistsError +from tasks.clean_notion_document_task import clean_notion_document_task from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task from tasks.document_indexing_task import document_indexing_task from tasks.document_indexing_update_task import document_indexing_update_task @@ -286,6 +288,24 @@ class DocumentService: return document @staticmethod + def get_document_by_dataset_id(dataset_id: str) -> List[Document]: + documents = db.session.query(Document).filter( + Document.dataset_id == dataset_id, + Document.enabled == True + ).all() + + 
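# only enabled documents are returned; disabled documents are deliberately skipped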
return documents
+
+    @staticmethod
+    def get_batch_documents(dataset_id: str, batch: str) -> List[Document]:
+        documents = db.session.query(Document).filter(
+            Document.batch == batch,
+            Document.dataset_id == dataset_id,
+            Document.tenant_id == current_user.current_tenant_id
+        ).all()
+
+        return documents
+
     @staticmethod
     def get_document_file_detail(file_id: str):
         file_detail = db.session.query(UploadFile). \
             filter(UploadFile.id == file_id). \
@@ -344,9 +364,9 @@ class DocumentService:
 
     @staticmethod
     def get_documents_position(dataset_id):
-        documents = Document.query.filter_by(dataset_id=dataset_id).all()
-        if documents:
-            return len(documents) + 1
+        document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
+        if document:
+            return document.position + 1
         else:
             return 1
 
@@ -363,9 +383,11 @@ class DocumentService:
 
         if dataset.indexing_technique == 'high_quality':
             IndexBuilder.get_default_service_context(dataset.tenant_id)
-
+        documents = []
+        batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
         if 'original_document_id' in document_data and document_data["original_document_id"]:
             document = DocumentService.update_document_with_dataset_id(dataset, document_data, account)
+            documents.append(document)
         else:
             # save process rule
             if not dataset_process_rule:
@@ -386,46 +408,114 @@ class DocumentService:
                 )
                 db.session.add(dataset_process_rule)
                 db.session.commit()
-
-            file_name = ''
-            data_source_info = {}
-            if document_data["data_source"]["type"] == "upload_file":
-                file_id = document_data["data_source"]["info"]
-                file = db.session.query(UploadFile).filter(
-                    UploadFile.tenant_id == dataset.tenant_id,
-                    UploadFile.id == file_id
-                ).first()
-
-                # raise error if file not found
-                if not file:
-                    raise FileNotExistsError()
-
-                file_name = file.name
-                data_source_info = {
-                    "upload_file_id": file_id,
-                }
-
-            # save document
             position = DocumentService.get_documents_position(dataset.id)
-            document = Document(
-                tenant_id=dataset.tenant_id,
-                dataset_id=dataset.id,
-                position=position,
-                data_source_type=document_data["data_source"]["type"],
-                data_source_info=json.dumps(data_source_info),
-                dataset_process_rule_id=dataset_process_rule.id,
-                batch=time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)),
-                name=file_name,
-                created_from=created_from,
-                created_by=account.id,
-                # created_api_request_id = db.Column(UUID, nullable=True)
-            )
+            document_ids = []
+            if document_data["data_source"]["type"] == "upload_file":
+                upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
+                for file_id in upload_file_list:
+                    file = db.session.query(UploadFile).filter(
+                        UploadFile.tenant_id == dataset.tenant_id,
+                        UploadFile.id == file_id
+                    ).first()
 
-            db.session.add(document)
+                    # raise error if file not found
+                    if not file:
+                        raise FileNotExistsError()
+
+                    file_name = file.name
+                    data_source_info = {
+                        "upload_file_id": file_id,
+                    }
+                    document = DocumentService.save_document(dataset, dataset_process_rule.id,
+                                                             document_data["data_source"]["type"],
+                                                             data_source_info, created_from, position,
+                                                             account, file_name, batch)
+                    db.session.add(document)
+                    db.session.flush()
+                    document_ids.append(document.id)
+                    documents.append(document)
+                    position += 1
+            elif document_data["data_source"]["type"] == "notion_import":
+                notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
+                exist_page_ids = []
+                exist_document = dict()
+                exist_documents = Document.query.filter_by(
+                    dataset_id=dataset.id,
+                    tenant_id=current_user.current_tenant_id,
+                    data_source_type='notion_import',
+                    enabled=True
+                ).all()
+                if exist_documents:
+                    for document in exist_documents:
+                        data_source_info = json.loads(document.data_source_info)
+                        exist_page_ids.append(data_source_info['notion_page_id'])
+                        exist_document[data_source_info['notion_page_id']] = document.id
+                for notion_info in notion_info_list:
+                    workspace_id = notion_info['workspace_id']
+                    data_source_binding = DataSourceBinding.query.filter(
+                        db.and_(
+                            DataSourceBinding.tenant_id == current_user.current_tenant_id,
+                            DataSourceBinding.provider == 'notion',
+                            DataSourceBinding.disabled == False,
+                            DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                        )
+                    ).first()
+                    if not data_source_binding:
+                        raise ValueError('Data source binding not found.')
+                    for page in notion_info['pages']:
+                        if page['page_id'] not in exist_page_ids:
+                            data_source_info = {
+                                "notion_workspace_id": workspace_id,
+                                "notion_page_id": page['page_id'],
+                                "notion_page_icon": page['page_icon'],
+                                "type": page['type']
+                            }
+                            document = DocumentService.save_document(dataset, dataset_process_rule.id,
+                                                                     document_data["data_source"]["type"],
+                                                                     data_source_info, created_from, position,
+                                                                     account, page['page_name'], batch)
+                            # if page['type'] == 'database':
+                            #     document.splitting_completed_at = datetime.datetime.utcnow()
+                            #     document.cleaning_completed_at = datetime.datetime.utcnow()
+                            #     document.parsing_completed_at = datetime.datetime.utcnow()
+                            #     document.completed_at = datetime.datetime.utcnow()
+                            #     document.indexing_status = 'completed'
+                            #     document.word_count = 0
+                            #     document.tokens = 0
+                            #     document.indexing_latency = 0
+                            db.session.add(document)
+                            db.session.flush()
+                            # if page['type'] != 'database':
+                            document_ids.append(document.id)
+                            documents.append(document)
+                            position += 1
+                        else:
+                            exist_document.pop(page['page_id'])
+                # delete not selected documents
+                if len(exist_document) > 0:
+                    clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
             db.session.commit()
 
             # trigger async task
-            document_indexing_task.delay(document.dataset_id, document.id)
+            document_indexing_task.delay(dataset.id, document_ids)
+
+        return documents, batch
+
+    @staticmethod
+    def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
+                      created_from: str, position: int, account: Account, name: str, batch: str):
+        document = Document(
+            tenant_id=dataset.tenant_id,
+            dataset_id=dataset.id,
+            position=position,
+            data_source_type=data_source_type,
+            data_source_info=json.dumps(data_source_info),
+            dataset_process_rule_id=process_rule_id,
+            batch=batch,
+            name=name,
+            created_from=created_from,
+            created_by=account.id,
+        )
 
         return document
 
     @staticmethod
@@ -460,20 +550,42 @@ class DocumentService:
         file_name = ''
         data_source_info = {}
         if document_data["data_source"]["type"] == "upload_file":
-            file_id = document_data["data_source"]["info"]
-            file = db.session.query(UploadFile).filter(
-                UploadFile.tenant_id == dataset.tenant_id,
-                UploadFile.id == file_id
-            ).first()
+            upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
+            for file_id in upload_file_list:
+                file = db.session.query(UploadFile).filter(
+                    UploadFile.tenant_id == dataset.tenant_id,
+                    UploadFile.id == file_id
+                ).first()
 
-            # raise error if file not found
-            if not file:
-                raise FileNotExistsError()
+                # raise error if file not found
+                if not file:
+                    raise FileNotExistsError()
 
-            file_name = file.name
-            data_source_info = {
-                "upload_file_id": file_id,
-            }
+                file_name = file.name
+                data_source_info = {
+                    "upload_file_id": file_id,
+                }
+        elif document_data["data_source"]["type"] == "notion_import":
+            notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
+            for notion_info in notion_info_list:
+                workspace_id = notion_info['workspace_id']
+                data_source_binding = DataSourceBinding.query.filter(
+                    db.and_(
+                        DataSourceBinding.tenant_id == current_user.current_tenant_id,
+                        DataSourceBinding.provider == 'notion',
+                        DataSourceBinding.disabled == False,
+                        DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                    )
+                ).first()
+                if not data_source_binding:
+                    raise ValueError('Data source binding not found.')
+                for page in notion_info['pages']:
+                    data_source_info = {
+                        "notion_workspace_id": workspace_id,
+                        "notion_page_id": page['page_id'],
+                        "notion_page_icon": page['page_icon'],
+                        "type": page['type']
+                    }
         document.data_source_type = document_data["data_source"]["type"]
         document.data_source_info = json.dumps(data_source_info)
         document.name = file_name
@@ -513,15 +625,15 @@ class DocumentService:
             db.session.add(dataset)
             db.session.flush()
 
-        document = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
+        documents, batch = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
 
         cut_length = 18
-        cut_name = document.name[:cut_length]
-        dataset.name = cut_name + '...' if len(document.name) > cut_length else cut_name
-        dataset.description = 'useful for when you want to answer queries about the ' + document.name
+        cut_name = documents[0].name[:cut_length]
+        dataset.name = cut_name + '...' if len(documents[0].name) > cut_length else cut_name
+        dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name
         db.session.commit()
 
-        return dataset, document
+        return dataset, documents, batch
 
     @classmethod
     def document_create_args_validate(cls, args: dict):
@@ -552,9 +664,15 @@ class DocumentService:
         if args['data_source']['type'] not in Document.DATA_SOURCES:
             raise ValueError("Data source type is invalid")
 
+        if 'info_list' not in args['data_source'] or not args['data_source']['info_list']:
+            raise ValueError("Data source info is required")
+
         if args['data_source']['type'] == 'upload_file':
-            if 'info' not in args['data_source'] or not args['data_source']['info']:
-                raise ValueError("Data source info is required")
+            if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['file_info_list']:
+                raise ValueError("File source info is required")
+        if args['data_source']['type'] == 'notion_import':
+            if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['notion_info_list']:
+                raise ValueError("Notion source info is required")
 
     @classmethod
     def process_rule_args_validate(cls, args: dict):
@@ -624,3 +742,78 @@ class DocumentService:
 
         if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
             raise ValueError("Process rule segmentation max_tokens is invalid")
+
+    @classmethod
+    def estimate_args_validate(cls, args: dict):
+        if 'info_list' not in args or not args['info_list']:
+            raise ValueError("Data source info is required")
+
+        if not isinstance(args['info_list'], dict):
+            raise ValueError("Data info is invalid")
+
+        if 'process_rule' not in args or not args['process_rule']:
+            raise ValueError("Process rule is required")
+
+        if not isinstance(args['process_rule'], dict):
+            raise ValueError("Process rule is invalid")
+
+        if 'mode' not in args['process_rule'] or not
args['process_rule']['mode']: + raise ValueError("Process rule mode is required") + + if args['process_rule']['mode'] not in DatasetProcessRule.MODES: + raise ValueError("Process rule mode is invalid") + + if args['process_rule']['mode'] == 'automatic': + args['process_rule']['rules'] = {} + else: + if 'rules' not in args['process_rule'] or not args['process_rule']['rules']: + raise ValueError("Process rule rules is required") + + if not isinstance(args['process_rule']['rules'], dict): + raise ValueError("Process rule rules is invalid") + + if 'pre_processing_rules' not in args['process_rule']['rules'] \ + or args['process_rule']['rules']['pre_processing_rules'] is None: + raise ValueError("Process rule pre_processing_rules is required") + + if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list): + raise ValueError("Process rule pre_processing_rules is invalid") + + unique_pre_processing_rule_dicts = {} + for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']: + if 'id' not in pre_processing_rule or not pre_processing_rule['id']: + raise ValueError("Process rule pre_processing_rules id is required") + + if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES: + raise ValueError("Process rule pre_processing_rules id is invalid") + + if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None: + raise ValueError("Process rule pre_processing_rules enabled is required") + + if not isinstance(pre_processing_rule['enabled'], bool): + raise ValueError("Process rule pre_processing_rules enabled is invalid") + + unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule + + args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values()) + + if 'segmentation' not in args['process_rule']['rules'] \ + or args['process_rule']['rules']['segmentation'] is None: + raise ValueError("Process rule segmentation is required") + + if not isinstance(args['process_rule']['rules']['segmentation'], dict): + raise ValueError("Process rule segmentation is invalid") + + if 'separator' not in args['process_rule']['rules']['segmentation'] \ + or not args['process_rule']['rules']['segmentation']['separator']: + raise ValueError("Process rule segmentation separator is required") + + if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str): + raise ValueError("Process rule segmentation separator is invalid") + + if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \ + or not args['process_rule']['rules']['segmentation']['max_tokens']: + raise ValueError("Process rule segmentation max_tokens is required") + + if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int): + raise ValueError("Process rule segmentation max_tokens is invalid") diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py new file mode 100644 index 000000000..3fdc03884 --- /dev/null +++ b/api/tasks/clean_notion_document_task.py @@ -0,0 +1,58 @@ +import logging +import time +from typing import List + +import click +from celery import shared_task + +from core.index.keyword_table_index import KeywordTableIndex +from core.index.vector_index import VectorIndex +from extensions.ext_database import db +from models.dataset import DocumentSegment, Dataset, Document + + +@shared_task +def clean_notion_document_task(document_ids: List[str], dataset_id: str): + """ + Clean document when document deleted. 
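+    Removes each document's segments from the vector and keyword indexes,
+    then deletes the segments and the document rows themselves.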
+    :param document_ids: document ids
+    :param dataset_id: dataset id
+
+    Usage: clean_notion_document_task.delay(document_ids, dataset_id)
+    """
+    logging.info(click.style('Start clean document when import from notion document deleted: {}'.format(dataset_id), fg='green'))
+    start_at = time.perf_counter()
+
+    try:
+        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+
+        if not dataset:
+            raise Exception('Document has no dataset')
+
+        vector_index = VectorIndex(dataset=dataset)
+        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        for document_id in document_ids:
+            document = db.session.query(Document).filter(
+                Document.id == document_id
+            ).first()
+            db.session.delete(document)
+            segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+            index_node_ids = [segment.index_node_id for segment in segments]
+
+            # delete from vector index
+            vector_index.del_nodes(index_node_ids)
+
+            # delete from keyword index
+            if index_node_ids:
+                keyword_table_index.del_nodes(index_node_ids)
+
+            for segment in segments:
+                db.session.delete(segment)
+        db.session.commit()
+        end_at = time.perf_counter()
+        logging.info(
+            click.style('Clean document when import from notion document deleted end :: {} latency: {}'.format(
+                dataset_id, end_at - start_at),
+                fg='green'))
+    except Exception:
+        logging.exception("Cleaned document when import from notion document deleted failed")
diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py
new file mode 100644
index 000000000..56869428d
--- /dev/null
+++ b/api/tasks/document_indexing_sync_task.py
@@ -0,0 +1,109 @@
+import datetime
+import logging
+import time
+
+import click
+from celery import shared_task
+from werkzeug.exceptions import NotFound
+
+from core.data_source.notion import NotionPageReader
+from core.index.keyword_table_index import KeywordTableIndex
+from core.index.vector_index import VectorIndex
+from core.indexing_runner import IndexingRunner, DocumentIsPausedException
+from core.llm.error import ProviderTokenNotInitError
+from extensions.ext_database import db
+from models.dataset import Document, Dataset, DocumentSegment
+from models.source import DataSourceBinding
+
+
+@shared_task
+def document_indexing_sync_task(dataset_id: str, document_id: str):
+    """
+    Async sync a Notion-imported document when its source page is updated
+    :param dataset_id:
+    :param document_id:
+
+    Usage: document_indexing_sync_task.delay(dataset_id, document_id)
+    """
+    logging.info(click.style('Start sync document: {}'.format(document_id), fg='green'))
+    start_at = time.perf_counter()
+
+    document = db.session.query(Document).filter(
+        Document.id == document_id,
+        Document.dataset_id == dataset_id
+    ).first()
+
+    if not document:
+        raise NotFound('Document not found')
+
+    data_source_info = document.data_source_info_dict
+    if document.data_source_type == 'notion_import':
+        if not data_source_info or 'notion_page_id' not in data_source_info \
+                or 'notion_workspace_id' not in data_source_info:
+            raise ValueError("no notion page found")
+        workspace_id = data_source_info['notion_workspace_id']
+        page_id = data_source_info['notion_page_id']
+        page_edited_time = data_source_info['last_edited_time']
+        data_source_binding = DataSourceBinding.query.filter(
+            db.and_(
+                DataSourceBinding.tenant_id == document.tenant_id,
+                DataSourceBinding.provider == 'notion',
+                DataSourceBinding.disabled == False,
+                DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+            )
+        ).first()
+        if not data_source_binding:
+            raise ValueError('Data source binding not found.')
+        reader = NotionPageReader(integration_token=data_source_binding.access_token)
+        last_edited_time = reader.get_page_last_edited_time(page_id)
+        # check whether the page has been updated since the last import
+        if last_edited_time != page_edited_time:
+            document.indexing_status = 'parsing'
+            document.processing_started_at = datetime.datetime.utcnow()
+            db.session.commit()
+
+            # delete all document segment and index
+            try:
+                dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+                if not dataset:
+                    raise Exception('Dataset not found')
+
+                vector_index = VectorIndex(dataset=dataset)
+                keyword_table_index = KeywordTableIndex(dataset=dataset)
+
+                segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+                index_node_ids = [segment.index_node_id for segment in segments]
+
+                # delete from vector index
+                vector_index.del_nodes(index_node_ids)
+
+                # delete from keyword index
+                if index_node_ids:
+                    keyword_table_index.del_nodes(index_node_ids)
+
+                for segment in segments:
+                    db.session.delete(segment)
+
+                end_at = time.perf_counter()
+                logging.info(
+                    click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
+            except Exception:
+                logging.exception("Cleaned document when document update data source or process rule failed")
+            try:
+                indexing_runner = IndexingRunner()
+                indexing_runner.run([document])
+                end_at = time.perf_counter()
+                logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
+            except DocumentIsPausedException:
+                logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
+            except ProviderTokenNotInitError as e:
+                document.indexing_status = 'error'
+                document.error = str(e.description)
+                document.stopped_at = datetime.datetime.utcnow()
+                db.session.commit()
+            except Exception as e:
+                logging.exception("consume update document failed")
+                document.indexing_status = 'error'
+                document.error = str(e)
+                document.stopped_at = datetime.datetime.utcnow()
+                db.session.commit()
diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py
index 59bbd4dc9..211d110fa 100644
--- a/api/tasks/document_indexing_task.py
+++ b/api/tasks/document_indexing_task.py
@@ -13,32 +13,36 @@ from models.dataset import Document
 
 
 @shared_task
-def document_indexing_task(dataset_id: str, document_id: str):
+def document_indexing_task(dataset_id: str, document_ids: list):
     """
     Async process document
     :param dataset_id:
-    :param document_id:
+    :param document_ids:
 
     Usage: document_indexing_task.delay(dataset_id, document_ids)
     """
-    logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
-    start_at = time.perf_counter()
+    documents = []
+    for document_id in document_ids:
+        logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
+        start_at = time.perf_counter()
 
-    document = db.session.query(Document).filter(
-        Document.id == document_id,
-        Document.dataset_id == dataset_id
-    ).first()
+        document = db.session.query(Document).filter(
+            Document.id == document_id,
+            Document.dataset_id == dataset_id
+        ).first()
 
-    if not document:
-        raise NotFound('Document not found')
+        if not document:
+            raise NotFound('Document not found')
 
-    document.indexing_status = 'parsing'
-    document.processing_started_at = datetime.datetime.utcnow()
+        document.indexing_status = 'parsing'
+
document.processing_started_at = datetime.datetime.utcnow() + documents.append(document) + db.session.add(document) db.session.commit() try: indexing_runner = IndexingRunner() - indexing_runner.run(document) + indexing_runner.run(documents) end_at = time.perf_counter() logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) except DocumentIsPausedException: diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 2aa98fb7b..8fee81f32 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -67,7 +67,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): logging.exception("Cleaned document when document update data source or process rule failed") try: indexing_runner = IndexingRunner() - indexing_runner.run(document) + indexing_runner.run([document]) end_at = time.perf_counter() logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) except DocumentIsPausedException: diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index c1a5d4336..3ab48e8a4 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -34,7 +34,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): try: indexing_runner = IndexingRunner() if document.indexing_status in ["waiting", "parsing", "cleaning"]: - indexing_runner.run(document) + indexing_runner.run([document]) elif document.indexing_status == "splitting": indexing_runner.run_in_splitting_status(document) elif document.indexing_status == "indexing": diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx index b05f8ce1f..e44d371b6 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx @@ -28,6 +28,7 @@ import Indicator from '@/app/components/header/indicator' import AppIcon from '@/app/components/base/app-icon' import Loading from '@/app/components/base/loading' import DatasetDetailContext from '@/context/dataset-detail' +import { DataSourceType } from '@/models/datasets' // import { fetchDatasetDetail } from '@/service/datasets' @@ -162,7 +163,7 @@ const DatasetDetailLayout: FC = (props) => { desc={datasetRes?.description || '--'} navigation={navigation} extraInfo={} - iconType='dataset' + iconType={datasetRes?.data_source_type === DataSourceType.NOTION ? 
'notion' : 'dataset'} />} +const NotionSvg = + + + + + + + + + + + + const ICON_MAP = { app: , api: , dataset: , webapp: , + notion: , } export default function AppBasic({ icon, icon_background, name, type, hoverTip, textStyle, iconType = 'app' }: IAppBasicProps) { diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx index eb3a44455..86f66df9d 100644 --- a/web/app/components/app-sidebar/index.tsx +++ b/web/app/components/app-sidebar/index.tsx @@ -4,7 +4,7 @@ import NavLink from './navLink' import AppBasic from './basic' export type IAppDetailNavProps = { - iconType?: 'app' | 'dataset' + iconType?: 'app' | 'dataset' | 'notion' title: string desc: string icon: string @@ -18,7 +18,6 @@ export type IAppDetailNavProps = { extraInfo?: React.ReactNode } - const AppDetailNav: FC = ({ title, desc, icon, icon_background, navigation, extraInfo, iconType = 'app' }) => { return (
diff --git a/web/app/components/base/checkbox/assets/check.svg b/web/app/components/base/checkbox/assets/check.svg new file mode 100644 index 000000000..f1f635ed7 --- /dev/null +++ b/web/app/components/base/checkbox/assets/check.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/components/base/checkbox/index.module.css b/web/app/components/base/checkbox/index.module.css new file mode 100644 index 000000000..102272e13 --- /dev/null +++ b/web/app/components/base/checkbox/index.module.css @@ -0,0 +1,9 @@ +.wrapper { + border-color: #d0d5dd; +} + +.checked { + background: #155eef url(./assets/check.svg) center center no-repeat; + background-size: 12px 12px; + border-color: #155eef; +} \ No newline at end of file diff --git a/web/app/components/base/checkbox/index.tsx b/web/app/components/base/checkbox/index.tsx new file mode 100644 index 000000000..da72116ce --- /dev/null +++ b/web/app/components/base/checkbox/index.tsx @@ -0,0 +1,19 @@ +import cn from 'classnames' +import s from './index.module.css' + +type CheckboxProps = { + checked?: boolean + onCheck?: () => void + className?: string +} + +const Checkbox = ({ checked, onCheck, className }: CheckboxProps) => { + return ( +
+ ) +} + +export default Checkbox diff --git a/web/app/components/base/notion-icon/index.module.css b/web/app/components/base/notion-icon/index.module.css new file mode 100644 index 000000000..2947260fd --- /dev/null +++ b/web/app/components/base/notion-icon/index.module.css @@ -0,0 +1,6 @@ +.default-page-icon { + width: 20px; + height: 20px; + background: url(../notion-page-selector/assets/notion-page.svg) center center no-repeat; + background-size: cover; +} \ No newline at end of file diff --git a/web/app/components/base/notion-icon/index.tsx b/web/app/components/base/notion-icon/index.tsx new file mode 100644 index 000000000..bc56094ab --- /dev/null +++ b/web/app/components/base/notion-icon/index.tsx @@ -0,0 +1,58 @@ +import cn from 'classnames' +import s from './index.module.css' +import type { DataSourceNotionPage } from '@/models/common' + +type IconTypes = 'workspace' | 'page' +type NotionIconProps = { + type?: IconTypes + name?: string | null + className?: string + src?: string | null | Pick['page_icon'] +} +const NotionIcon = ({ + type = 'workspace', + src, + name, + className, +}: NotionIconProps) => { + if (type === 'workspace') { + if (typeof src === 'string') { + if (src.startsWith('https://') || src.startsWith('http://')) { + return ( + workspace icon + ) + } + return ( +
{src}
+ ) + } + return ( +
{name?.[0].toLocaleUpperCase()}
+ ) + } + + if (typeof src === 'object' && src !== null) { + if (src?.type === 'url') { + return ( + page icon + ) + } + return ( +
{src?.emoji}
+ ) + } + + return ( +
+ ) +} + +export default NotionIcon diff --git a/web/app/components/base/notion-page-selector/assets/clear.svg b/web/app/components/base/notion-page-selector/assets/clear.svg new file mode 100644 index 000000000..3d1bbf55f --- /dev/null +++ b/web/app/components/base/notion-page-selector/assets/clear.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/components/base/notion-page-selector/assets/down-arrow.svg b/web/app/components/base/notion-page-selector/assets/down-arrow.svg new file mode 100644 index 000000000..0676e96e3 --- /dev/null +++ b/web/app/components/base/notion-page-selector/assets/down-arrow.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/components/base/notion-page-selector/assets/notion-empty-page.svg b/web/app/components/base/notion-page-selector/assets/notion-empty-page.svg new file mode 100644 index 000000000..7493621ac --- /dev/null +++ b/web/app/components/base/notion-page-selector/assets/notion-empty-page.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/components/base/notion-page-selector/assets/notion-page.svg b/web/app/components/base/notion-page-selector/assets/notion-page.svg new file mode 100644 index 000000000..237fc2ee5 --- /dev/null +++ b/web/app/components/base/notion-page-selector/assets/notion-page.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/components/base/notion-page-selector/assets/search.svg b/web/app/components/base/notion-page-selector/assets/search.svg new file mode 100644 index 000000000..1d083d05c --- /dev/null +++ b/web/app/components/base/notion-page-selector/assets/search.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/web/app/components/base/notion-page-selector/assets/setting.svg b/web/app/components/base/notion-page-selector/assets/setting.svg new file mode 100644 index 000000000..6d3ecf53c --- /dev/null +++ b/web/app/components/base/notion-page-selector/assets/setting.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/web/app/components/base/notion-page-selector/base.module.css b/web/app/components/base/notion-page-selector/base.module.css new file mode 100644 index 000000000..d9aa9fe35 --- /dev/null +++ b/web/app/components/base/notion-page-selector/base.module.css @@ -0,0 +1,4 @@ +.setting-icon { + background: url(./assets/setting.svg) center center no-repeat; + background-size: 14px 14px; +} \ No newline at end of file diff --git a/web/app/components/base/notion-page-selector/base.tsx b/web/app/components/base/notion-page-selector/base.tsx new file mode 100644 index 000000000..276b66e1a --- /dev/null +++ b/web/app/components/base/notion-page-selector/base.tsx @@ -0,0 +1,141 @@ +import { useCallback, useEffect, useMemo, useState } from 'react' +import useSWR from 'swr' +import cn from 'classnames' +import s from './base.module.css' +import WorkspaceSelector from './workspace-selector' +import SearchInput from './search-input' +import PageSelector from './page-selector' +import { preImportNotionPages } from '@/service/datasets' +import AccountSetting from '@/app/components/header/account-setting' +import { NotionConnector } from '@/app/components/datasets/create/step-one' +import type { DataSourceNotionPage, DataSourceNotionPageMap, DataSourceNotionWorkspace } from '@/models/common' + +export type NotionPageSelectorValue = DataSourceNotionPage & { workspace_id: string } + +type NotionPageSelectorProps = { + value?: string[] + onSelect: (selectedPages: NotionPageSelectorValue[]) => void + canPreview?: boolean + previewPageId?: string + onPreview?: (selectedPage: NotionPageSelectorValue) => void + datasetId?: string +} + +const NotionPageSelector 
= ({ + value, + onSelect, + canPreview, + previewPageId, + onPreview, + datasetId = '', +}: NotionPageSelectorProps) => { + const { data, mutate } = useSWR({ url: '/notion/pre-import/pages', datasetId }, preImportNotionPages) + const [prevData, setPrevData] = useState(data) + const [searchValue, setSearchValue] = useState('') + const [showDataSourceSetting, setShowDataSourceSetting] = useState(false) + const [currentWorkspaceId, setCurrentWorkspaceId] = useState('') + + const notionWorkspaces = useMemo(() => { + return data?.notion_info || [] + }, [data?.notion_info]) + const firstWorkspaceId = notionWorkspaces[0]?.workspace_id + const currentWorkspace = notionWorkspaces.find(workspace => workspace.workspace_id === currentWorkspaceId) + + const getPagesMapAndSelectedPagesId: [DataSourceNotionPageMap, Set] = useMemo(() => { + const selectedPagesId = new Set() + const pagesMap = notionWorkspaces.reduce((prev: DataSourceNotionPageMap, next: DataSourceNotionWorkspace) => { + next.pages.forEach((page) => { + if (page.is_bound) + selectedPagesId.add(page.page_id) + prev[page.page_id] = { + ...page, + workspace_id: next.workspace_id, + } + }) + + return prev + }, {}) + return [pagesMap, selectedPagesId] + }, [notionWorkspaces]) + const defaultSelectedPagesId = [...Array.from(getPagesMapAndSelectedPagesId[1]), ...(value || [])] + const [selectedPagesId, setSelectedPagesId] = useState>(new Set(defaultSelectedPagesId)) + + if (prevData !== data) { + setPrevData(data) + setSelectedPagesId(new Set(defaultSelectedPagesId)) + } + + const handleSearchValueChange = useCallback((value: string) => { + setSearchValue(value) + }, []) + const handleSelectWorkspace = useCallback((workspaceId: string) => { + setCurrentWorkspaceId(workspaceId) + }, []) + const handleSelecPages = (selectedPagesId: Set) => { + setSelectedPagesId(new Set(Array.from(selectedPagesId))) + const selectedPages = Array.from(selectedPagesId).map(pageId => getPagesMapAndSelectedPagesId[0][pageId]) + onSelect(selectedPages) + } + const handlePreviewPage = (previewPageId: string) => { + if (onPreview) + onPreview(getPagesMapAndSelectedPagesId[0][previewPageId]) + } + + useEffect(() => { + setCurrentWorkspaceId(firstWorkspaceId) + }, [firstWorkspaceId]) + + return ( +
+ { + data?.notion_info?.length + ? ( + <> +
+ +
+
setShowDataSourceSetting(true)} + /> +
+ +
+
+ +
+ + ) + : ( + setShowDataSourceSetting(true)} /> + ) + } + { + showDataSourceSetting && ( + { + setShowDataSourceSetting(false) + mutate() + }} /> + ) + } +
+ ) +} + +export default NotionPageSelector diff --git a/web/app/components/base/notion-page-selector/index.tsx b/web/app/components/base/notion-page-selector/index.tsx new file mode 100644 index 000000000..0a44c70fe --- /dev/null +++ b/web/app/components/base/notion-page-selector/index.tsx @@ -0,0 +1,2 @@ +export { default as NotionPageSelectorModal } from './notion-page-selector-modal' +export { default as NotionPageSelector } from './base' diff --git a/web/app/components/base/notion-page-selector/notion-page-selector-modal/index.module.css b/web/app/components/base/notion-page-selector/notion-page-selector-modal/index.module.css new file mode 100644 index 000000000..ed9091601 --- /dev/null +++ b/web/app/components/base/notion-page-selector/notion-page-selector-modal/index.module.css @@ -0,0 +1,28 @@ +.modal { + width: 600px !important; + max-width: 600px !important; + padding: 24px 32px !important; +} + +.operate { + padding: 0 8px; + min-width: 96px; + height: 36px; + line-height: 36px; + text-align: center; + background-color: #ffffff; + box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); + border-radius: 8px; + border: 0.5px solid #eaecf0; + font-size: 14px; + font-weight: 500; + color: #667085; + cursor: pointer; +} + +.operate-save { + margin-left: 8px; + border-color: #155eef; + background-color: #155eef; + color: #ffffff; +} \ No newline at end of file diff --git a/web/app/components/base/notion-page-selector/notion-page-selector-modal/index.tsx b/web/app/components/base/notion-page-selector/notion-page-selector-modal/index.tsx new file mode 100644 index 000000000..944eb4781 --- /dev/null +++ b/web/app/components/base/notion-page-selector/notion-page-selector-modal/index.tsx @@ -0,0 +1,62 @@ +import { useState } from 'react' +import { useTranslation } from 'react-i18next' +import cn from 'classnames' +import { XMarkIcon } from '@heroicons/react/24/outline' +import NotionPageSelector from '../base' +import type { NotionPageSelectorValue } from '../base' +import s from './index.module.css' +import Modal from '@/app/components/base/modal' + +type NotionPageSelectorModalProps = { + isShow: boolean + onClose: () => void + onSave: (selectedPages: NotionPageSelectorValue[]) => void + datasetId: string +} +const NotionPageSelectorModal = ({ + isShow, + onClose, + onSave, + datasetId, +}: NotionPageSelectorModalProps) => { + const { t } = useTranslation() + const [selectedPages, setSelectedPages] = useState([]) + + const handleClose = () => { + onClose() + } + const handleSelectPage = (newSelectedPages: NotionPageSelectorValue[]) => { + setSelectedPages(newSelectedPages) + } + const handleSave = () => { + onSave(selectedPages) + } + + return ( + {}} + > +
+
{t('common.dataSource.notion.selector.addPages')}
+
+ +
+
+ +
+
{t('common.operation.cancel')}
+
{t('common.operation.save')}
+
+
+ ) +} + +export default NotionPageSelectorModal diff --git a/web/app/components/base/notion-page-selector/page-selector/index.module.css b/web/app/components/base/notion-page-selector/page-selector/index.module.css new file mode 100644 index 000000000..1542095b8 --- /dev/null +++ b/web/app/components/base/notion-page-selector/page-selector/index.module.css @@ -0,0 +1,17 @@ +.arrow { + width: 20px; + height: 20px; + background: url(../assets/down-arrow.svg) center center no-repeat; + background-size: 16px 16px; + transform: rotate(-90deg); +} + +.arrow-expand { + transform: rotate(0); +} + +.preview-item { + background-color: #eff4ff; + border: 1px solid #D1E0FF; + box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); +} \ No newline at end of file diff --git a/web/app/components/base/notion-page-selector/page-selector/index.tsx b/web/app/components/base/notion-page-selector/page-selector/index.tsx new file mode 100644 index 000000000..101640172 --- /dev/null +++ b/web/app/components/base/notion-page-selector/page-selector/index.tsx @@ -0,0 +1,299 @@ +import { memo, useMemo, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { FixedSizeList as List, areEqual } from 'react-window' +import type { ListChildComponentProps } from 'react-window' +import cn from 'classnames' +import Checkbox from '../../checkbox' +import NotionIcon from '../../notion-icon' +import s from './index.module.css' +import type { DataSourceNotionPage, DataSourceNotionPageMap } from '@/models/common' + +type PageSelectorProps = { + value: Set + searchValue: string + pagesMap: DataSourceNotionPageMap + list: DataSourceNotionPage[] + onSelect: (selectedPagesId: Set) => void + canPreview?: boolean + previewPageId?: string + onPreview?: (selectedPageId: string) => void +} +type NotionPageTreeItem = { + children: Set + descendants: Set + deepth: number + ancestors: string[] +} & DataSourceNotionPage +type NotionPageTreeMap = Record +type NotionPageItem = { + expand: boolean + deepth: number +} & DataSourceNotionPage + +const recursivePushInParentDescendants = ( + pagesMap: DataSourceNotionPageMap, + listTreeMap: NotionPageTreeMap, + current: NotionPageTreeItem, + leafItem: NotionPageTreeItem, +) => { + const parentId = current.parent_id + const pageId = current.page_id + + if (!parentId || !pageId) + return + + if (parentId !== 'root' && pagesMap[parentId]) { + if (!listTreeMap[parentId]) { + const children = new Set([pageId]) + const descendants = new Set([pageId, leafItem.page_id]) + listTreeMap[parentId] = { + ...pagesMap[parentId], + children, + descendants, + deepth: 0, + ancestors: [], + } + } + else { + listTreeMap[parentId].children.add(pageId) + listTreeMap[parentId].descendants.add(pageId) + listTreeMap[parentId].descendants.add(leafItem.page_id) + } + leafItem.deepth++ + leafItem.ancestors.unshift(listTreeMap[parentId].page_name) + + if (listTreeMap[parentId].parent_id !== 'root') + recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap[parentId], leafItem) + } +} + +const Item = memo(({ index, style, data }: ListChildComponentProps<{ + dataList: NotionPageItem[] + handleToggle: (index: number) => void + checkedIds: Set + handleCheck: (index: number) => void + canPreview?: boolean + handlePreview: (index: number) => void + listMapWithChildrenAndDescendants: NotionPageTreeMap + searchValue: string + previewPageId: string + pagesMap: DataSourceNotionPageMap +}>) => { + const { t } = useTranslation() + const { dataList, handleToggle, checkedIds, handleCheck, canPreview, 
handlePreview, listMapWithChildrenAndDescendants, searchValue, previewPageId, pagesMap } = data + const current = dataList[index] + const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[current.page_id] + const hasChild = currentWithChildrenAndDescendants.descendants.size > 0 + const ancestors = currentWithChildrenAndDescendants.ancestors + const breadCrumbs = ancestors.length ? [...ancestors, current.page_name] : [current.page_name] + + const renderArrow = () => { + if (hasChild) { + return ( +
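Editor's note: for orientation amid the stripped JSX above, `recursivePushInParentDescendants` walks from each page up its `parent_id` chain, registering the page in every ancestor's `children`/`descendants` sets, bumping the leaf's depth counter, and prepending ancestor names for the breadcrumb shown in search results. A minimal standalone sketch of that bookkeeping, assuming only the `page_id`/`parent_id`/`page_name` fields; the helper name `buildTreeMap` is illustrative, not from the patch:

```ts
// Illustrative only: mirrors the ancestor walk in page-selector/index.tsx.
type PageStub = { page_id: string; parent_id: string; page_name: string }
type TreeItem = PageStub & {
  children: Set<string>
  descendants: Set<string>
  deepth: number // spelling kept to match the component's field
  ancestors: string[]
}

function buildTreeMap(pages: PageStub[]): Record<string, TreeItem> {
  const pagesMap: Record<string, PageStub> = Object.fromEntries(pages.map(p => [p.page_id, p]))
  const treeMap: Record<string, TreeItem> = {}
  const ensure = (p: PageStub): TreeItem =>
    (treeMap[p.page_id] ??= { ...p, children: new Set(), descendants: new Set(), deepth: 0, ancestors: [] })

  for (const page of pages) {
    const leaf = ensure(page)
    // Climb the parent_id chain until we hit 'root' or an unknown page.
    let current: PageStub = page
    while (current.parent_id !== 'root' && pagesMap[current.parent_id]) {
      const parent = ensure(pagesMap[current.parent_id])
      parent.children.add(current.page_id) // direct child
      parent.descendants.add(current.page_id) // every level below
      parent.descendants.add(leaf.page_id)
      leaf.deepth++ // indent level for rendering
      leaf.ancestors.unshift(parent.page_name) // breadcrumb, root-first
      current = parent
    }
  }
  return treeMap
}
```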
handleToggle(index)} + /> + ) + } + if (current.parent_id === 'root' || !pagesMap[current.parent_id]) { + return ( +
+ ) + } + return ( +
+ ) + } + + return ( +
+ handleCheck(index)} + /> + {!searchValue && renderArrow()} + +
+ {current.page_name} +
+ { + canPreview && ( +
handlePreview(index)}> + {t('common.dataSource.notion.selector.preview')} +
+ ) + } + { + searchValue && ( +
+ {breadCrumbs.join(' / ')} +
+ ) + } +
+ ) +}, areEqual) + +const PageSelector = ({ + value, + searchValue, + pagesMap, + list, + onSelect, + canPreview = true, + previewPageId, + onPreview, +}: PageSelectorProps) => { + const { t } = useTranslation() + const [prevDataList, setPrevDataList] = useState(list) + const [dataList, setDataList] = useState([]) + const [localPreviewPageId, setLocalPreviewPageId] = useState('') + if (prevDataList !== list) { + setPrevDataList(list) + setDataList(list.filter(item => item.parent_id === 'root' || !pagesMap[item.parent_id]).map((item) => { + return { + ...item, + expand: false, + deepth: 0, + } + })) + } + const searchDataList = list.filter((item) => { + return item.page_name.includes(searchValue) + }).map((item) => { + return { + ...item, + expand: false, + deepth: 0, + } + }) + const currentDataList = searchValue ? searchDataList : dataList + const currentPreviewPageId = previewPageId === undefined ? localPreviewPageId : previewPageId + + const listMapWithChildrenAndDescendants = useMemo(() => { + return list.reduce((prev: NotionPageTreeMap, next: DataSourceNotionPage) => { + const pageId = next.page_id + if (!prev[pageId]) + prev[pageId] = { ...next, children: new Set(), descendants: new Set(), deepth: 0, ancestors: [] } + + recursivePushInParentDescendants(pagesMap, prev, prev[pageId], prev[pageId]) + return prev + }, {}) + }, [list, pagesMap]) + + const handleToggle = (index: number) => { + const current = dataList[index] + const pageId = current.page_id + const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId] + const descendantsIds = Array.from(currentWithChildrenAndDescendants.descendants) + const childrenIds = Array.from(currentWithChildrenAndDescendants.children) + let newDataList = [] + + if (current.expand) { + current.expand = false + + newDataList = [...dataList.filter(item => !descendantsIds.includes(item.page_id))] + } + else { + current.expand = true + + newDataList = [ + ...dataList.slice(0, index + 1), + ...childrenIds.map(item => ({ + ...pagesMap[item], + expand: false, + deepth: listMapWithChildrenAndDescendants[item].deepth, + })), + ...dataList.slice(index + 1)] + } + setDataList(newDataList) + } + + const handleCheck = (index: number) => { + const current = currentDataList[index] + const pageId = current.page_id + const currentWithChildrenAndDescendants = listMapWithChildrenAndDescendants[pageId] + + if (value.has(pageId)) { + if (!searchValue) { + for (const item of currentWithChildrenAndDescendants.descendants) + value.delete(item) + } + + value.delete(pageId) + } + else { + if (!searchValue) { + for (const item of currentWithChildrenAndDescendants.descendants) + value.add(item) + } + + value.add(pageId) + } + + onSelect(new Set([...value])) + } + + const handlePreview = (index: number) => { + const current = currentDataList[index] + const pageId = current.page_id + + setLocalPreviewPageId(pageId) + + if (onPreview) + onPreview(pageId) + } + + if (!currentDataList.length) { + return ( +
+ {t('common.dataSource.notion.selector.noSearchResult')} +
+ ) + } + + return ( + data.dataList[index].page_id} + itemData={{ + dataList: currentDataList, + handleToggle, + checkedIds: value, + handleCheck, + canPreview, + handlePreview, + listMapWithChildrenAndDescendants, + searchValue, + previewPageId: currentPreviewPageId, + pagesMap, + }} + > + {Item} + + ) +} + +export default PageSelector diff --git a/web/app/components/base/notion-page-selector/search-input/index.module.css b/web/app/components/base/notion-page-selector/search-input/index.module.css new file mode 100644 index 000000000..a65b7d539 --- /dev/null +++ b/web/app/components/base/notion-page-selector/search-input/index.module.css @@ -0,0 +1,15 @@ +.search-icon { + background: url(../assets/search.svg) center center; + background-size: 14px 14px; +} + +.clear-icon { + background: url(../assets/clear.svg) center center; + background-size: contain; +} + +.input-wrapper { + flex-basis: 200px; + width: 0; + box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); +} \ No newline at end of file diff --git a/web/app/components/base/notion-page-selector/search-input/index.tsx b/web/app/components/base/notion-page-selector/search-input/index.tsx new file mode 100644 index 000000000..1a41a4c09 --- /dev/null +++ b/web/app/components/base/notion-page-selector/search-input/index.tsx @@ -0,0 +1,42 @@ +import { useCallback } from 'react' +import type { ChangeEvent } from 'react' +import { useTranslation } from 'react-i18next' +import cn from 'classnames' +import s from './index.module.css' + +type SearchInputProps = { + value: string + onChange: (v: string) => void +} +const SearchInput = ({ + value, + onChange, +}: SearchInputProps) => { + const { t } = useTranslation() + + const handleClear = useCallback(() => { + onChange('') + }, [onChange]) + + return ( +
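Editor's note: the selector renders this flattened tree through react-window's `FixedSizeList`, as the hunk above shows, so only visible rows mount even for large workspaces. Shared state travels to rows via `itemData`, and rows are wrapped in `memo(..., areEqual)` so a row re-renders only when its inputs change. A minimal sketch of the same pattern with placeholder row content:

```tsx
import { memo } from 'react'
import type { ListChildComponentProps } from 'react-window'
import { FixedSizeList as List, areEqual } from 'react-window'

type RowData = { names: string[] }

// Rows get their position via `style` and shared state via `data`;
// memo(..., areEqual) skips re-rendering rows whose props are unchanged.
const Row = memo(({ index, style, data }: ListChildComponentProps<RowData>) => (
  <div style={style}>{data.names[index]}</div>
), areEqual)

const VirtualList = ({ names }: RowData) => (
  <List
    height={296} // viewport height; only rows inside it mount
    width='100%'
    itemCount={names.length}
    itemSize={28} // a fixed row height is what keeps the scroll math cheap
    itemData={{ names }}
  >
    {Row}
  </List>
)

export default VirtualList
```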
+
+ ) => onChange(e.target.value)} + placeholder={t('common.dataSource.notion.selector.searchPages') || ''} + /> + { + value && ( +
+ ) + } +
+ ) +} + +export default SearchInput diff --git a/web/app/components/base/notion-page-selector/workspace-selector/index.module.css b/web/app/components/base/notion-page-selector/workspace-selector/index.module.css new file mode 100644 index 000000000..b68e1561e --- /dev/null +++ b/web/app/components/base/notion-page-selector/workspace-selector/index.module.css @@ -0,0 +1,9 @@ +.down-arrow { + background: url(../assets/down-arrow.svg) center center no-repeat; + background-size: cover; +} + +.popup { + box-shadow: 0px 12px 16px -4px rgba(16, 24, 40, 0.08), 0px 4px 6px -2px rgba(16, 24, 40, 0.03); + z-index: 10; +} \ No newline at end of file diff --git a/web/app/components/base/notion-page-selector/workspace-selector/index.tsx b/web/app/components/base/notion-page-selector/workspace-selector/index.tsx new file mode 100644 index 000000000..bc340e43a --- /dev/null +++ b/web/app/components/base/notion-page-selector/workspace-selector/index.tsx @@ -0,0 +1,84 @@ +'use client' +import { useTranslation } from 'react-i18next' +import { Fragment } from 'react' +import { Menu, Transition } from '@headlessui/react' +import cn from 'classnames' +import NotionIcon from '../../notion-icon' +import s from './index.module.css' +import type { DataSourceNotionWorkspace } from '@/models/common' + +type WorkspaceSelectorProps = { + value: string + items: Omit[] + onSelect: (v: string) => void +} +export default function WorkspaceSelector({ + value, + items, + onSelect, +}: WorkspaceSelectorProps) { + const { t } = useTranslation() + const currentWorkspace = items.find(item => item.workspace_id === value) + + return ( + + { + ({ open }) => ( + <> + + +
{currentWorkspace?.workspace_name}
+
{currentWorkspace?.pages.length}
+
+ + + +
+ { + items.map(item => ( + +
onSelect(item.workspace_id)} + > + +
{item.workspace_name}
+
+ {item.pages.length} {t('common.dataSource.notion.selector.pageSelected')} +
+
+
+ )) + } +
+
+
+ + ) + } +
+ ) +} diff --git a/web/app/components/base/progress-bar/index.tsx b/web/app/components/base/progress-bar/index.tsx new file mode 100644 index 000000000..f0fd2a830 --- /dev/null +++ b/web/app/components/base/progress-bar/index.tsx @@ -0,0 +1,20 @@ +type ProgressBarProps = { + percent: number +} +const ProgressBar = ({ + percent = 0, +}: ProgressBarProps) => { + return ( +
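Editor's note: the workspace selector above (and the `Operate` menu later in this patch) follows the standard @headlessui/react idiom: `Menu` exposes an `open` render prop for styling the trigger, and `Transition as={Fragment}` animates the popup without adding a wrapper node. A stripped-down sketch with placeholder items and classes:

```tsx
import { Fragment } from 'react'
import { Menu, Transition } from '@headlessui/react'

export const TinyMenu = ({ items, onPick }: { items: string[]; onPick: (v: string) => void }) => (
  <Menu as='div' className='relative'>
    {({ open }) => (
      <>
        {/* `open` lets the trigger restyle itself while the popup is shown */}
        <Menu.Button className={open ? 'bg-gray-100' : ''}>Choose…</Menu.Button>
        <Transition
          as={Fragment}
          enter='transition ease-out duration-100'
          enterFrom='opacity-0'
          enterTo='opacity-100'
        >
          <Menu.Items className='absolute z-10'>
            {items.map(item => (
              <Menu.Item key={item}>
                <div onClick={() => onPick(item)}>{item}</div>
              </Menu.Item>
            ))}
          </Menu.Items>
        </Transition>
      </>
    )}
  </Menu>
)
```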
+
+
+
+
{percent}%
+
+ ) +} + +export default ProgressBar diff --git a/web/app/components/datasets/create/assets/Icon-3-dots.svg b/web/app/components/datasets/create/assets/Icon-3-dots.svg new file mode 100644 index 000000000..0955e5d82 --- /dev/null +++ b/web/app/components/datasets/create/assets/Icon-3-dots.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/components/datasets/create/assets/normal.svg b/web/app/components/datasets/create/assets/normal.svg new file mode 100644 index 000000000..1d94adffe --- /dev/null +++ b/web/app/components/datasets/create/assets/normal.svg @@ -0,0 +1,4 @@ + + + + diff --git a/web/app/components/datasets/create/assets/star.svg b/web/app/components/datasets/create/assets/star.svg new file mode 100644 index 000000000..3ffda5794 --- /dev/null +++ b/web/app/components/datasets/create/assets/star.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/web/app/components/datasets/create/embedding-process/index.module.css b/web/app/components/datasets/create/embedding-process/index.module.css new file mode 100644 index 000000000..9269ce201 --- /dev/null +++ b/web/app/components/datasets/create/embedding-process/index.module.css @@ -0,0 +1,111 @@ +.progressContainer { + @apply relative pb-4 w-full; + border-bottom: 0.5px solid #EAECF0; +} +.sourceItem { + position: relative; + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 4px; + padding: 0 4px; + height: 24px; + background: #EFF4FF; + border-radius: 6px; + overflow: hidden; +} +.sourceItem.error { + background: #FEE4E2; +} +.sourceItem.success { + background: #D1FADF; +} +.progressbar { + position: absolute; + top: 0; + left: 0; + height: 100%; + background-color: #B2CCFF; +} +.sourceItem .info { + display: flex; + align-items: center; +} +.sourceItem .info .name { + font-weight: 500; + font-size: 12px; + line-height: 18px; + color: #101828; +} +.sourceItem.success .info .name { + color: #05603A; +} +.sourceItem .percent { + font-weight: 500; + font-size: 12px; + line-height: 18px; + color: #344054; +} +.sourceItem .error { + color: #D92D20; +} +.sourceItem .success { + color: #05603A; +} + + +.cost { + @apply flex justify-between items-center text-xs text-gray-700; +} +.embeddingStatus { + @apply flex items-center justify-between text-gray-900 font-medium text-sm mr-2; +} +.commonIcon { + @apply w-3 h-3 mr-1 inline-block align-middle; +} +.highIcon { + mask-image: url(../assets/star.svg); + @apply bg-orange-500; +} +.economyIcon { + background-color: #444ce7; + mask-image: url(../assets/normal.svg); +} +.tokens { + @apply text-xs font-medium px-1; +} +.price { + color: #f79009; + @apply text-xs font-medium; +} + +.fileIcon { + @apply w-4 h-4 mr-1 bg-center bg-no-repeat; + background-image: url(../assets/unknow.svg); + background-size: 16px; +} +.fileIcon.csv { + background-image: url(../assets/csv.svg); +} + +.fileIcon.xlsx, +.fileIcon.xls { + background-image: url(../assets/xlsx.svg); +} +.fileIcon.pdf { + background-image: url(../assets/pdf.svg); +} +.fileIcon.html, +.fileIcon.htm { + background-image: url(../assets/html.svg); +} +.fileIcon.md, +.fileIcon.markdown { + background-image: url(../assets/md.svg); +} +.fileIcon.txt { + background-image: url(../assets/txt.svg); +} +.fileIcon.json { + background-image: url(../assets/json.svg); +} diff --git a/web/app/components/datasets/create/embedding-process/index.tsx b/web/app/components/datasets/create/embedding-process/index.tsx new file mode 100644 index 000000000..6b311b549 --- /dev/null +++ 
b/web/app/components/datasets/create/embedding-process/index.tsx @@ -0,0 +1,242 @@ +import type { FC } from 'react' +import React, { useCallback, useEffect, useMemo } from 'react' +import useSWR from 'swr' +import { useRouter } from 'next/navigation' +import { useTranslation } from 'react-i18next' +import { omit } from 'lodash-es' +import { ArrowRightIcon } from '@heroicons/react/24/solid' +import { useGetState } from 'ahooks' +import cn from 'classnames' +import s from './index.module.css' +import { FieldInfo } from '@/app/components/datasets/documents/detail/metadata' +import Button from '@/app/components/base/button' +import type { FullDocumentDetail, IndexingStatusResponse, ProcessRuleResponse } from '@/models/datasets' +import { formatNumber } from '@/utils/format' +import { fetchIndexingStatusBatch as doFetchIndexingStatus, fetchIndexingEstimateBatch, fetchProcessRule } from '@/service/datasets' +import { DataSourceType } from '@/models/datasets' +import NotionIcon from '@/app/components/base/notion-icon' + +type Props = { + datasetId: string + batchId: string + documents?: FullDocumentDetail[] + indexingType?: string +} + +const RuleDetail: FC<{ sourceData?: ProcessRuleResponse }> = ({ sourceData }) => { + const { t } = useTranslation() + + const segmentationRuleMap = { + mode: t('datasetDocuments.embedding.mode'), + segmentLength: t('datasetDocuments.embedding.segmentLength'), + textCleaning: t('datasetDocuments.embedding.textCleaning'), + } + + const getRuleName = (key: string) => { + if (key === 'remove_extra_spaces') + return t('datasetCreation.stepTwo.removeExtraSpaces') + + if (key === 'remove_urls_emails') + return t('datasetCreation.stepTwo.removeUrlEmails') + + if (key === 'remove_stopwords') + return t('datasetCreation.stepTwo.removeStopwords') + } + + const getValue = useCallback((field: string) => { + let value: string | number | undefined = '-' + switch (field) { + case 'mode': + value = sourceData?.mode === 'automatic' ? (t('datasetDocuments.embedding.automatic') as string) : (t('datasetDocuments.embedding.custom') as string) + break + case 'segmentLength': + value = sourceData?.rules?.segmentation?.max_tokens + break + default: + value = sourceData?.mode === 'automatic' + ? (t('datasetDocuments.embedding.automatic') as string) + // eslint-disable-next-line array-callback-return + : sourceData?.rules?.pre_processing_rules?.map((rule) => { + if (rule.enabled) + return getRuleName(rule.id) + }).filter(Boolean).join(';') + break + } + return value + }, [sourceData]) + + return
+ {Object.keys(segmentationRuleMap).map((field) => { + return + })} +
+} + +const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], indexingType }) => { + const { t } = useTranslation() + + const getFirstDocument = documents[0] + + const [indexingStatusBatchDetail, setIndexingStatusDetail, getIndexingStatusDetail] = useGetState([]) + const fetchIndexingStatus = async () => { + const status = await doFetchIndexingStatus({ datasetId, batchId }) + setIndexingStatusDetail(status.data) + } + + const [runId, setRunId, getRunId] = useGetState(null) + + const stopQueryStatus = () => { + clearInterval(getRunId()) + } + + const startQueryStatus = () => { + const runId = setInterval(() => { + const indexingStatusBatchDetail = getIndexingStatusDetail() + const isCompleted = indexingStatusBatchDetail.every(indexingStatusDetail => ['completed', 'error'].includes(indexingStatusDetail.indexing_status)) + if (isCompleted) { + stopQueryStatus() + return + } + fetchIndexingStatus() + }, 2500) + setRunId(runId) + } + + useEffect(() => { + fetchIndexingStatus() + startQueryStatus() + return () => { + stopQueryStatus() + } + }, []) + + // get rule + const { data: ruleDetail, error: ruleError } = useSWR({ + action: 'fetchProcessRule', + params: { documentId: getFirstDocument.id }, + }, apiParams => fetchProcessRule(omit(apiParams, 'action')), { + revalidateOnFocus: false, + }) + // get cost + const { data: indexingEstimateDetail, error: indexingEstimateErr } = useSWR({ + action: 'fetchIndexingEstimateBatch', + datasetId, + batchId, + }, apiParams => fetchIndexingEstimateBatch(omit(apiParams, 'action')), { + revalidateOnFocus: false, + }) + + const router = useRouter() + const navToDocumentList = () => { + router.push(`/datasets/${datasetId}/documents`) + } + + const isEmbedding = useMemo(() => { + return indexingStatusBatchDetail.some((indexingStatusDetail: { indexing_status: any }) => ['indexing', 'splitting', 'parsing', 'cleaning'].includes(indexingStatusDetail?.indexing_status || '')) + }, [indexingStatusBatchDetail]) + const isEmbeddingCompleted = useMemo(() => { + return indexingStatusBatchDetail.every((indexingStatusDetail: { indexing_status: any }) => ['completed', 'error'].includes(indexingStatusDetail?.indexing_status || '')) + }, [indexingStatusBatchDetail]) + + const getSourceName = (id: string) => { + const doc = documents.find(document => document.id === id) + return doc?.name + } + const getFileType = (name?: string) => name?.split('.').pop() || 'txt' + const getSourcePercent = (detail: IndexingStatusResponse) => { + const completedCount = detail.completed_segments || 0 + const totalCount = detail.total_segments || 0 + if (totalCount === 0) + return 0 + const percent = Math.round(completedCount * 100 / totalCount) + return percent > 100 ? 100 : percent + } + const getSourceType = (id: string) => { + const doc = documents.find(document => document.id === id) + return doc?.data_source_type as DataSourceType + } + const getIcon = (id: string) => { + const doc = documents.find(document => document.id === id) as any // TODO type fix + + return doc.data_source_info.notion_page_icon + } + const isSourceEmbedding = (detail: IndexingStatusResponse) => ['indexing', 'splitting', 'parsing', 'cleaning', 'waiting'].includes(detail.indexing_status || '') + + return ( + <> +
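Editor's note: the embedding screen polls batch indexing status every 2.5 s and stops once every document reports `completed` or `error`. Because the interval callback is created once, the component reads state through ahooks' `useGetState` getters rather than closed-over values, which would be stale. A self-contained sketch of that shape, with the fetcher supplied by the caller:

```ts
import { useEffect } from 'react'
import { useGetState } from 'ahooks'

type Status = { indexing_status: string }

// A minimal sketch of the polling shape used by EmbeddingProcess,
// assuming a caller-supplied fetcher.
export function usePollStatuses(fetchStatuses: () => Promise<Status[]>) {
  const [, setStatuses, getStatuses] = useGetState<Status[]>([])

  useEffect(() => {
    const tick = async () => setStatuses(await fetchStatuses())
    tick() // prime once, then poll
    const timer = setInterval(() => {
      // getStatuses() reads the *current* state, so this once-created
      // callback never checks a stale snapshot.
      const done = getStatuses().every(s => ['completed', 'error'].includes(s.indexing_status))
      if (done)
        clearInterval(timer)
      else
        tick()
    }, 2500)
    return () => clearInterval(timer)
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [])
}
```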
+
+ {isEmbedding && t('datasetDocuments.embedding.processing')} + {isEmbeddingCompleted && t('datasetDocuments.embedding.completed')} +
+
+ {indexingType === 'high_quality' && ( +
+
+ {t('datasetDocuments.embedding.highQuality')} · {t('datasetDocuments.embedding.estimate')} + {formatNumber(indexingEstimateDetail?.tokens || 0)}tokens + (${formatNumber(indexingEstimateDetail?.total_price || 0)}) +
+ )} + {indexingType === 'economy' && ( +
+
+ {t('datasetDocuments.embedding.economy')} · {t('datasetDocuments.embedding.estimate')} + 0tokens +
+ )} +
+
+
+ {indexingStatusBatchDetail.map(indexingStatusDetail => ( +
+ {isSourceEmbedding(indexingStatusDetail) && ( +
+ )} +
+ {getSourceType(indexingStatusDetail.id) === DataSourceType.FILE && ( +
+ )} + {getSourceType(indexingStatusDetail.id) === DataSourceType.NOTION && ( + + )} +
{getSourceName(indexingStatusDetail.id)}
+
+
+ {isSourceEmbedding(indexingStatusDetail) && ( +
{`${getSourcePercent(indexingStatusDetail)}%`}
+ )} + {indexingStatusDetail.indexing_status === 'error' && ( +
Error
+ )} + {indexingStatusDetail.indexing_status === 'completed' && ( +
100%
+ )} +
+
+ ))} +
+ +
+ +
+ + ) +} + +export default EmbeddingProcess diff --git a/web/app/components/datasets/create/file-preview/index.module.css b/web/app/components/datasets/create/file-preview/index.module.css index f64f49364..d87522e6d 100644 --- a/web/app/components/datasets/create/file-preview/index.module.css +++ b/web/app/components/datasets/create/file-preview/index.module.css @@ -11,6 +11,9 @@ } .previewHeader .title { + display: flex; + justify-content: space-between; + align-items: center; color: #101828; font-weight: 600; font-size: 18px; diff --git a/web/app/components/datasets/create/file-preview/index.tsx b/web/app/components/datasets/create/file-preview/index.tsx index ad9f866c1..738013f19 100644 --- a/web/app/components/datasets/create/file-preview/index.tsx +++ b/web/app/components/datasets/create/file-preview/index.tsx @@ -1,18 +1,21 @@ 'use client' -import React, { useState, useEffect } from 'react' +import React, { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' +import cn from 'classnames' +import { XMarkIcon } from '@heroicons/react/20/solid' +import s from './index.module.css' import type { File } from '@/models/datasets' import { fetchFilePreview } from '@/service/common' -import cn from 'classnames' -import s from './index.module.css' - type IProps = { - file?: File, + file?: File + notionPage?: any + hidePreview: () => void } const FilePreview = ({ file, + hidePreview, }: IProps) => { const { t } = useTranslation() const [previewContent, setPreviewContent] = useState('') @@ -28,23 +31,27 @@ const FilePreview = ({ } const getFileName = (currentFile?: File) => { - if (!currentFile) { + if (!currentFile) return '' - } + const arr = currentFile.name.split('.') return arr.slice(0, -1).join() } useEffect(() => { - if (file) { + if (file) getPreviewContent(file.id) - } }, [file]) return (
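Editor's note: a fetch idiom that recurs throughout this patch, including the two `useSWR` calls in `EmbeddingProcess` above: the SWR key is an object carrying an `action` discriminator so otherwise identical parameter shapes cache separately, and the fetcher strips it with lodash's `omit` before calling the API. Sketched with the process-rule call; the fetcher here is a declared stand-in:

```ts
import useSWR from 'swr'
import { omit } from 'lodash-es'

type RuleParams = { documentId: string }
declare function fetchProcessRule(params: RuleParams): Promise<{ mode: string }>

export const useProcessRule = (documentId: string) =>
  useSWR(
    // `action` keeps this cache entry distinct from other requests
    // that happen to share the same parameter shape.
    { action: 'fetchProcessRule', documentId },
    apiParams => fetchProcessRule(omit(apiParams, 'action')),
    { revalidateOnFocus: false }, // as in the patch: no refetch on tab focus
  )
```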
-
{t('datasetCreation.stepOne.filePreview')}
+
+ {t('datasetCreation.stepOne.filePreview')} +
+ +
+
{getFileName(file)}.{file?.extension}
diff --git a/web/app/components/datasets/create/index.tsx b/web/app/components/datasets/create/index.tsx index 344e7ee86..54e55e930 100644 --- a/web/app/components/datasets/create/index.tsx +++ b/web/app/components/datasets/create/index.tsx @@ -1,32 +1,44 @@ 'use client' -import React, { useState, useCallback, useEffect } from 'react' +import React, { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' -import type { DataSet, File, createDocumentResponse } from '@/models/datasets' -import { fetchTenantInfo } from '@/service/common' -import { fetchDataDetail } from '@/service/datasets' - +import AppUnavailable from '../../base/app-unavailable' import StepsNavBar from './steps-nav-bar' import StepOne from './step-one' import StepTwo from './step-two' import StepThree from './step-three' +import { DataSourceType } from '@/models/datasets' +import type { DataSet, File, createDocumentResponse } from '@/models/datasets' +import { fetchDataSource, fetchTenantInfo } from '@/service/common' +import { fetchDataDetail } from '@/service/datasets' +import type { DataSourceNotionPage } from '@/models/common' + import AccountSetting from '@/app/components/header/account-setting' -import AppUnavailable from '../../base/app-unavailable' + +type Page = DataSourceNotionPage & { workspace_id: string } type DatasetUpdateFormProps = { - datasetId?: string; + datasetId?: string } const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { const { t } = useTranslation() const [hasSetAPIKEY, setHasSetAPIKEY] = useState(true) const [isShowSetAPIKey, { setTrue: showSetAPIKey, setFalse: hideSetAPIkey }] = useBoolean() + const [hasConnection, setHasConnection] = useState(true) + const [isShowDataSourceSetting, { setTrue: showDataSourceSetting, setFalse: hideDataSourceSetting }] = useBoolean() + const [dataSourceType, setDataSourceType] = useState(DataSourceType.FILE) const [step, setStep] = useState(1) const [indexingTypeCache, setIndexTypeCache] = useState('') const [file, setFile] = useState() const [result, setResult] = useState() const [hasError, setHasError] = useState(false) + const [notionPages, setNotionPages] = useState([]) + const updateNotionPages = (value: Page[]) => { + setNotionPages(value) + } + const updateFile = (file?: File) => { setFile(file) } @@ -50,9 +62,15 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { const hasSetKey = data.providers.some(({ is_valid }) => is_valid) setHasSetAPIKEY(hasSetKey) } + const checkNotionConnection = async () => { + const { data } = await fetchDataSource({ url: '/data-source/integrates' }) + const hasConnection = data.filter(item => item.provider === 'notion') || [] + setHasConnection(hasConnection.length > 0) + } useEffect(() => { checkAPIKey() + checkNotionConnection() }, []) const [detail, setDetail] = useState(null) @@ -62,16 +80,16 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { try { const detail = await fetchDataDetail(datasetId) setDetail(detail) - } catch (e) { + } + catch (e) { setHasError(true) } } })() }, [datasetId]) - if (hasError) { + if (hasError) return - } return (
@@ -80,9 +98,16 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => {
{step === 1 && } {(step === 2 && (!datasetId || (datasetId && !!detail))) && { onSetting={showSetAPIKey} indexingType={detail?.indexing_technique || ''} datasetId={datasetId} + dataSourceType={dataSourceType} file={file} + notionPages={notionPages} onStepChange={changeStep} updateIndexingTypeCache={updateIndexingTypeCache} updateResultCache={updateResultCache} @@ -106,6 +133,7 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { await checkAPIKey() hideSetAPIkey() }} />} + {isShowDataSourceSetting && }
) } diff --git a/web/app/components/datasets/create/notion-page-preview/index.module.css b/web/app/components/datasets/create/notion-page-preview/index.module.css new file mode 100644 index 000000000..12d374735 --- /dev/null +++ b/web/app/components/datasets/create/notion-page-preview/index.module.css @@ -0,0 +1,54 @@ +.filePreview { + @apply flex flex-col border-l border-gray-200 shrink-0; + width: 528px; + background-color: #fcfcfd; + } + + .previewHeader { + @apply border-b border-gray-200 shrink-0; + margin: 42px 32px 0; + padding-bottom: 16px; + } + + .previewHeader .title { + display: flex; + justify-content: space-between; + align-items: center; + color: #101828; + font-weight: 600; + font-size: 18px; + line-height: 28px; + } + + .previewHeader .fileName { + display: flex; + align-items: center; + font-weight: 400; + font-size: 12px; + line-height: 18px; + color: #1D2939; + } + + .previewHeader .filetype { + color: #667085; + } + + .previewContent { + @apply overflow-y-auto grow; + padding: 20px 32px; + font-weight: 400; + font-size: 16px; + line-height: 24px; + color: #344054; + } + + .previewContent .loading { + width: 100%; + height: 180px; + background: #f9fafb center no-repeat url(../assets/Loading.svg); + background-size: contain; + } + .fileContent { + white-space: pre-line; + } + \ No newline at end of file diff --git a/web/app/components/datasets/create/notion-page-preview/index.tsx b/web/app/components/datasets/create/notion-page-preview/index.tsx new file mode 100644 index 000000000..4a55c1fdc --- /dev/null +++ b/web/app/components/datasets/create/notion-page-preview/index.tsx @@ -0,0 +1,75 @@ +'use client' +import React, { useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import cn from 'classnames' +import { XMarkIcon } from '@heroicons/react/20/solid' +import s from './index.module.css' +import type { DataSourceNotionPage } from '@/models/common' +import NotionIcon from '@/app/components/base/notion-icon' +import { fetchNotionPagePreview } from '@/service/datasets' + +type Page = DataSourceNotionPage & { workspace_id: string } +type IProps = { + currentPage?: Page + hidePreview: () => void +} + +const NotionPagePreview = ({ + currentPage, + hidePreview, +}: IProps) => { + const { t } = useTranslation() + const [previewContent, setPreviewContent] = useState('') + const [loading, setLoading] = useState(true) + + const getPreviewContent = async () => { + if (!currentPage) + return + try { + const res = await fetchNotionPagePreview({ + workspaceID: currentPage.workspace_id, + pageID: currentPage.page_id, + pageType: currentPage.type, + }) + setPreviewContent(res.content) + setLoading(false) + } + catch {} + } + + useEffect(() => { + if (currentPage) { + setLoading(true) + getPreviewContent() + } + }, [currentPage]) + + return ( +
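Editor's note: `checkNotionConnection` in the create-flow hunk above derives a boolean from the integrates endpoint; since `Array.prototype.filter` always returns an array, the `|| []` fallback there is redundant. A tightened sketch of the same check, keeping the endpoint from the patch and declaring the service call as a stand-in:

```ts
type Integrate = { provider: string; is_bound: boolean }
declare function fetchDataSource(opts: { url: string }): Promise<{ data: Integrate[] }>

const checkNotionConnection = async (): Promise<boolean> => {
  const { data } = await fetchDataSource({ url: '/data-source/integrates' })
  // `some` states the intent directly: at least one Notion integration exists.
  return data.some(item => item.provider === 'notion')
}
```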
+
+
+ {t('datasetCreation.stepOne.pagePreview')} +
+ +
+
+
+ + {currentPage?.page_name} +
+
+
+ {loading &&
} + {!loading && ( +
{previewContent}
+ )} +
+
+ ) +} + +export default NotionPagePreview diff --git a/web/app/components/datasets/create/step-one/index.module.css b/web/app/components/datasets/create/step-one/index.module.css index bf391bf8d..f2e2c8523 100644 --- a/web/app/components/datasets/create/step-one/index.module.css +++ b/web/app/components/datasets/create/step-one/index.module.css @@ -107,3 +107,53 @@ background: center no-repeat url(../assets/folder-plus.svg); background-size: contain; } + +.notionConnectionTip { + display: flex; + flex-direction: column; + align-items: flex-start; + padding: 24px; + max-width: 640px; + background: #F9FAFB; + border-radius: 16px; +} + +.notionIcon { + display: flex; + padding: 12px; + width: 48px; + height: 48px; + background: #fff center no-repeat url(../assets/notion.svg); + background-size: 24px; + border: 0.5px solid #EAECF5; + box-shadow: 0px 12px 16px -4px rgba(16, 24, 40, 0.08), 0px 4px 6px -2px rgba(16, 24, 40, 0.03); + border-radius: 12px; +} + +.notionConnectionTip .title { + position: relative; + margin: 24px 0 4px; + font-style: normal; + font-weight: 600; + font-size: 16px; + line-height: 24px; + color: #374151; +} +.notionConnectionTip .title::after { + content: ''; + position: absolute; + top: -6px; + right: -12px; + width: 16px; + height: 16px; + background: center no-repeat url(../assets/Icon-3-dots.svg); + background-size: contain; +} +.notionConnectionTip .tip { + margin-bottom: 20px; + font-style: normal; + font-weight: 400; + font-size: 13px; + line-height: 18px; + color: #6B7280; +} diff --git a/web/app/components/datasets/create/step-one/index.tsx b/web/app/components/datasets/create/step-one/index.tsx index b93fba910..bec16f57d 100644 --- a/web/app/components/datasets/create/step-one/index.tsx +++ b/web/app/components/datasets/create/step-one/index.tsx @@ -1,36 +1,82 @@ 'use client' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' -import type { File } from '@/models/datasets' +import cn from 'classnames' import FilePreview from '../file-preview' import FileUploader from '../file-uploader' +import NotionPagePreview from '../notion-page-preview' import EmptyDatasetCreationModal from '../empty-dataset-creation-modal' -import Button from '@/app/components/base/button' - -import cn from 'classnames' import s from './index.module.css' +import type { File } from '@/models/datasets' +import type { DataSourceNotionPage } from '@/models/common' +import { DataSourceType } from '@/models/datasets' +import Button from '@/app/components/base/button' +import { NotionPageSelector } from '@/app/components/base/notion-page-selector' type IStepOneProps = { - datasetId?: string, - file?: File, - updateFile: (file?: File) => void, - onStepChange: () => void, + datasetId?: string + dataSourceType: DataSourceType + dataSourceTypeDisable: Boolean + hasConnection: boolean + onSetting: () => void + file?: File + updateFile: (file?: File) => void + notionPages?: any[] + updateNotionPages: (value: any[]) => void + onStepChange: () => void + changeType: (type: DataSourceType) => void +} + +type Page = DataSourceNotionPage & { workspace_id: string } + +type NotionConnectorProps = { + onSetting: () => void +} +export const NotionConnector = ({ onSetting }: NotionConnectorProps) => { + const { t } = useTranslation() + + return ( +
+ +
{t('datasetCreation.stepOne.notionSyncTitle')}
+
{t('datasetCreation.stepOne.notionSyncTip')}
+ +
+ ) } const StepOne = ({ datasetId, + dataSourceType, + dataSourceTypeDisable, + changeType, + hasConnection, + onSetting, onStepChange, file, updateFile, + notionPages = [], + updateNotionPages, }: IStepOneProps) => { - const [dataSourceType, setDataSourceType] = useState('FILE') const [showModal, setShowModal] = useState(false) + const [showFilePreview, setShowFilePreview] = useState(true) + const [currentNotionPage, setCurrentNotionPage] = useState() const { t } = useTranslation() + const hidePreview = () => setShowFilePreview(false) + const modalShowHandle = () => setShowModal(true) const modalCloseHandle = () => setShowModal(false) + const updateCurrentPage = (page: Page) => { + setCurrentNotionPage(page) + } + + const hideNotionPagePreview = () => { + setCurrentNotionPage(undefined) + } + return (
@@ -38,41 +84,76 @@ const StepOne = ({
setDataSourceType('FILE')} + className={cn( + s.dataSourceItem, + dataSourceType === DataSourceType.FILE && s.active, + dataSourceTypeDisable && dataSourceType !== DataSourceType.FILE && s.disabled, + )} + onClick={() => { + if (dataSourceTypeDisable) + return + changeType(DataSourceType.FILE) + hidePreview() + }} > - + {t('datasetCreation.stepOne.dataSourceType.file')}
setDataSourceType('notion')} + className={cn( + s.dataSourceItem, + dataSourceType === DataSourceType.NOTION && s.active, + dataSourceTypeDisable && dataSourceType !== DataSourceType.NOTION && s.disabled, + )} + onClick={() => { + if (dataSourceTypeDisable) + return + changeType(DataSourceType.NOTION) + hidePreview() + }} > - Coming soon - + {t('datasetCreation.stepOne.dataSourceType.notion')}
setDataSourceType('web')} + className={cn(s.dataSourceItem, s.disabled, dataSourceType === DataSourceType.WEB && s.active)} + // onClick={() => changeType(DataSourceType.WEB)} > Coming soon - + {t('datasetCreation.stepOne.dataSourceType.web')}
- - + {dataSourceType === DataSourceType.FILE && ( + <> + + + + )} + {dataSourceType === DataSourceType.NOTION && ( + <> + {!hasConnection && } + {hasConnection && ( + <> +
+ page.page_id)} onSelect={updateNotionPages} onPreview={updateCurrentPage} /> +
+ + + )} + + )} {!datasetId && ( <> -
+
{t('datasetCreation.stepOne.emptyDatasetCreation')}
)}
- +
- {file && } + {file && showFilePreview && } + {currentNotionPage && }
) } diff --git a/web/app/components/datasets/create/step-three/index.tsx b/web/app/components/datasets/create/step-three/index.tsx index 04eab14e5..0ca2b0fdc 100644 --- a/web/app/components/datasets/create/step-three/index.tsx +++ b/web/app/components/datasets/create/step-three/index.tsx @@ -1,16 +1,16 @@ 'use client' import React from 'react' import { useTranslation } from 'react-i18next' -import type { createDocumentResponse } from '@/models/datasets' -import EmbeddingDetail from '../../documents/detail/embedding' - import cn from 'classnames' +import EmbeddingProcess from '../embedding-process' + import s from './index.module.css' +import type { FullDocumentDetail, createDocumentResponse } from '@/models/datasets' type StepThreeProps = { - datasetId?: string, - datasetName?: string, - indexingType?: string, + datasetId?: string + datasetName?: string + indexingType?: string creationCache?: createDocumentResponse } @@ -38,12 +38,11 @@ const StepThree = ({ datasetId, datasetName, indexingType, creationCache }: Step
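Editor's note: step one now drives file, Notion, and (future) web sources through a single `dataSourceType` prop; the click handlers above no-op once `dataSourceTypeDisable` locks the choice, and close any open preview before switching. A compact sketch of that guard; only the FILE and NOTION backing strings appear in this patch, so WEB is omitted:

```ts
// Only the FILE and NOTION values are visible in this diff.
export enum DataSourceType {
  FILE = 'upload_file',
  NOTION = 'notion_import',
}

type SwitchDeps = {
  disabled: boolean
  changeType: (t: DataSourceType) => void
  hidePreview: () => void
}

// Mirrors the onClick guard in step-one: no-op while locked, otherwise
// switch the source and dismiss whatever preview pane is open.
export const makeSwitchHandler = ({ disabled, changeType, hidePreview }: SwitchDeps) =>
  (next: DataSourceType) => {
    if (disabled)
      return
    changeType(next)
    hidePreview()
  }
```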
{`${t('datasetCreation.stepThree.additionP1')} ${datasetName || creationCache?.dataset?.name} ${t('datasetCreation.stepThree.additionP2')}`}
)} -
@@ -58,4 +57,4 @@ const StepThree = ({ datasetId, datasetName, indexingType, creationCache }: Step ) } -export default StepThree; +export default StepThree diff --git a/web/app/components/datasets/create/step-two/index.module.css b/web/app/components/datasets/create/step-two/index.module.css index 94820cf9e..7d56b9032 100644 --- a/web/app/components/datasets/create/step-two/index.module.css +++ b/web/app/components/datasets/create/step-two/index.module.css @@ -14,9 +14,26 @@ } .fixed { + padding-top: 12px; + font-size: 12px; + line-height: 18px; background: rgba(255, 255, 255, 0.9); border-bottom: 0.5px solid #EAECF0; backdrop-filter: blur(4px); + animation: fix 0.5s; +} + +@keyframes fix { + from { + padding-top: 42px; + font-size: 18px; + line-height: 28px; + } + to { + padding-top: 12px; + font-size: 12px; + line-height: 18px; + } } .form { @@ -273,11 +290,11 @@ @apply bg-gray-100 caret-primary-600 hover:bg-gray-100 focus:ring-1 focus:ring-inset focus:ring-gray-200 focus-visible:outline-none focus:bg-white placeholder:text-gray-400; } -.file { +.source { @apply flex justify-between items-center mt-8 px-6 py-4 rounded-xl bg-gray-50; } -.file .divider { +.source .divider { @apply shrink-0 mx-4 w-px bg-gray-200; height: 42px; } @@ -318,9 +335,19 @@ .fileIcon.json { background-image: url(../assets/json.svg); } - -.fileContent { - flex: 1 1 50%; +.sourceContent { + flex: 1 1 auto; +} +.sourceCount { + @apply shrink-0 ml-1; + font-weight: 500; + font-size: 13px; + line-height: 18px; + color: #667085; +} +.segmentCount { + flex: 1 1 30%; + max-width: 120px; } .divider { diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index 8e6650770..c1e13d773 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -6,9 +6,10 @@ import { useBoolean } from 'ahooks' import { XMarkIcon } from '@heroicons/react/20/solid' import cn from 'classnames' import Link from 'next/link' +import { groupBy } from 'lodash-es' import PreviewItem from './preview-item' import s from './index.module.css' -import type { CreateDocumentReq, File, FullDocumentDetail, FileIndexingEstimateResponse as IndexingEstimateResponse, PreProcessingRule, Rules, createDocumentResponse } from '@/models/datasets' +import type { CreateDocumentReq, File, FullDocumentDetail, FileIndexingEstimateResponse as IndexingEstimateResponse, NotionInfo, PreProcessingRule, Rules, createDocumentResponse } from '@/models/datasets' import { createDocument, createFirstDocument, @@ -20,6 +21,11 @@ import Loading from '@/app/components/base/loading' import Toast from '@/app/components/base/toast' import { formatNumber } from '@/utils/format' +import type { DataSourceNotionPage } from '@/models/common' +import { DataSourceType } from '@/models/datasets' +import NotionIcon from '@/app/components/base/notion-icon' + +type Page = DataSourceNotionPage & { workspace_id: string } type StepTwoProps = { isSetting?: boolean @@ -28,7 +34,9 @@ type StepTwoProps = { onSetting: () => void datasetId?: string indexingType?: string + dataSourceType: DataSourceType file?: File + notionPages?: Page[] onStepChange?: (delta: number) => void updateIndexingTypeCache?: (type: string) => void updateResultCache?: (res: createDocumentResponse) => void @@ -52,7 +60,9 @@ const StepTwo = ({ onSetting, datasetId, indexingType, + dataSourceType, file, + notionPages = [], onStepChange, updateIndexingTypeCache, updateResultCache, @@ -169,12 +179,54 @@ 
const StepTwo = ({ return processRule } + const getNotionInfo = () => { + const workspacesMap = groupBy(notionPages, 'workspace_id') + const workspaces = Object.keys(workspacesMap).map((workspaceId) => { + return { + workspaceId, + pages: workspacesMap[workspaceId], + } + }) + return workspaces.map((workspace) => { + return { + workspace_id: workspace.workspaceId, + pages: workspace.pages.map((page) => { + const { page_id, page_name, page_icon, type } = page + return { + page_id, + page_name, + page_icon, + type, + } + }), + } + }) as NotionInfo[] + } + const getFileIndexingEstimateParams = () => { - const params = { - file_id: file?.id, - dataset_id: datasetId, - indexing_technique: getIndexing_technique(), - process_rule: getProcessRule(), + let params + if (dataSourceType === DataSourceType.FILE) { + params = { + info_list: { + data_source_type: dataSourceType, + file_info_list: { + // TODO multi files + file_ids: [file?.id || ''], + }, + }, + indexing_technique: getIndexing_technique(), + process_rule: getProcessRule(), + } + } + if (dataSourceType === DataSourceType.NOTION) { + params = { + info_list: { + data_source_type: dataSourceType, + notion_info_list: getNotionInfo(), + }, + indexing_technique: getIndexing_technique(), + process_rule: getProcessRule(), + } } return params } @@ -190,13 +242,22 @@ const StepTwo = ({ else { params = { data_source: { - type: 'upload_file', - info: file?.id, - name: file?.name, + type: dataSourceType, + info_list: { + data_source_type: dataSourceType, + }, }, indexing_technique: getIndexing_technique(), process_rule: getProcessRule(), } as CreateDocumentReq + if (dataSourceType === DataSourceType.FILE) { + params.data_source.info_list.file_info_list = { + // TODO multi files + file_ids: [file?.id || ''], + } + } + if (dataSourceType === DataSourceType.NOTION) + params.data_source.info_list.notion_info_list = getNotionInfo() } return params } @@ -249,9 +310,7 @@ const StepTwo = ({ body: params, }) updateIndexingTypeCache && updateIndexingTypeCache(indexType) - updateResultCache && updateResultCache({ - document: res, - }) + updateResultCache && updateResultCache(res) } onStepChange && onStepChange(+1) isSetting && onSave && onSave() @@ -319,7 +378,6 @@ const StepTwo = ({
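Editor's note: `getNotionInfo` above reshapes the flat selection (each page tagged with its `workspace_id`) into the `notion_info_list` payload: one entry per workspace, each page trimmed to the four fields the create-document API consumes. The same transform in isolation, with the icon shape left opaque:

```ts
import { groupBy } from 'lodash-es'

type SelectedPage = {
  workspace_id: string
  page_id: string
  page_name: string
  page_icon: unknown // icon shape elided here
  type: string
}

type NotionInfo = {
  workspace_id: string
  pages: Pick<SelectedPage, 'page_id' | 'page_name' | 'page_icon' | 'type'>[]
}

const toNotionInfoList = (pages: SelectedPage[]): NotionInfo[] =>
  Object.entries(groupBy(pages, 'workspace_id')).map(([workspaceId, wsPages]) => ({
    workspace_id: workspaceId,
    // keep only the fields the create-document API consumes
    pages: wsPages.map(({ page_id, page_name, page_icon, type }) =>
      ({ page_id, page_name, page_icon, type })),
  }))
```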
{t('datasetCreation.stepTwo.segmentation')}
-
{t('datasetCreation.stepTwo.datasetSettingLink')}
)} -
-
-
{t('datasetCreation.stepTwo.fileName')}
-
- - {getFileName(file?.name || '')} -
+ {/* TODO multi files */} +
+
+ {dataSourceType === DataSourceType.FILE && ( + <> +
{t('datasetCreation.stepTwo.fileSource')}
+
+ + {getFileName(file?.name || '')} +
+ + )} + {dataSourceType === DataSourceType.NOTION && ( + <> +
{t('datasetCreation.stepTwo.notionSource')}
+
+ + {notionPages[0]?.page_name} + {notionPages.length > 1 && ( + + {t('datasetCreation.stepTwo.other')} + {notionPages.length - 1} + {t('datasetCreation.stepTwo.notionUnit')} + + )} +
+ + )}
-
+
{t('datasetCreation.stepTwo.emstimateSegment')}
{ diff --git a/web/app/components/datasets/documents/detail/index.tsx b/web/app/components/datasets/documents/detail/index.tsx index b865a4f4c..3424b355a 100644 --- a/web/app/components/datasets/documents/detail/index.tsx +++ b/web/app/components/datasets/documents/detail/index.tsx @@ -8,15 +8,16 @@ import { useTranslation } from 'react-i18next' import { useRouter } from 'next/navigation' import { omit } from 'lodash-es' import cn from 'classnames' -import Divider from '@/app/components/base/divider' -import Loading from '@/app/components/base/loading' -import { fetchDocumentDetail, MetadataType } from '@/service/datasets' import { OperationAction, StatusItem } from '../list' +import s from '../style.module.css' import Completed from './completed' import Embedding from './embedding' import Metadata from './metadata' -import s from '../style.module.css' import style from './style.module.css' +import Divider from '@/app/components/base/divider' +import Loading from '@/app/components/base/loading' +import type { MetadataType } from '@/service/datasets' +import { fetchDocumentDetail } from '@/service/datasets' export const BackCircleBtn: FC<{ onClick: () => void }> = ({ onClick }) => { return ( @@ -29,11 +30,11 @@ export const BackCircleBtn: FC<{ onClick: () => void }> = ({ onClick }) => { export const DocumentContext = createContext<{ datasetId?: string; documentId?: string }>({}) type DocumentTitleProps = { - extension?: string; - name?: string; - iconCls?: string; - textCls?: string; - wrapperCls?: string; + extension?: string + name?: string + iconCls?: string + textCls?: string + wrapperCls?: string } export const DocumentTitle: FC = ({ extension, name, iconCls, textCls, wrapperCls }) => { @@ -58,15 +59,16 @@ const DocumentDetail: FC = ({ datasetId, documentId }) => { action: 'fetchDocumentDetail', datasetId, documentId, - params: { metadata: 'without' as MetadataType } + params: { metadata: 'without' as MetadataType }, }, apiParams => fetchDocumentDetail(omit(apiParams, 'action'))) const { data: documentMetadata, error: metadataErr, mutate: metadataMutate } = useSWR({ action: 'fetchDocumentDetail', datasetId, documentId, - params: { metadata: 'only' as MetadataType } - }, apiParams => fetchDocumentDetail(omit(apiParams, 'action'))) + params: { metadata: 'only' as MetadataType }, + }, apiParams => fetchDocumentDetail(omit(apiParams, 'action')), + ) const backToPrev = () => { router.push(`/datasets/${datasetId}/documents`) @@ -77,6 +79,13 @@ const DocumentDetail: FC = ({ datasetId, documentId }) => { const embedding = ['queuing', 'indexing', 'paused'].includes((documentDetail?.display_status || '').toLowerCase()) + const handleOperate = (operateName?: string) => { + if (operateName === 'delete') + backToPrev() + else + detailMutate() + } + return (
@@ -90,10 +99,10 @@ const DocumentDetail: FC = ({ datasetId, documentId }) => { detail={{ enabled: documentDetail?.enabled || false, archived: documentDetail?.archived || false, - id: documentId + id: documentId, }} datasetId={datasetId} - onUpdate={detailMutate} + onUpdate={handleOperate} className='!w-[216px]' />
- {isDetailLoading ? : -
+ {isDetailLoading + ? + :
{embedding ? : }
} diff --git a/web/app/components/datasets/documents/detail/settings/index.tsx b/web/app/components/datasets/documents/detail/settings/index.tsx index d1dae06b4..43d9c9491 100644 --- a/web/app/components/datasets/documents/detail/settings/index.tsx +++ b/web/app/components/datasets/documents/detail/settings/index.tsx @@ -1,5 +1,5 @@ 'use client' -import React, { useEffect, useState } from 'react' +import React, { useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' import { useContext } from 'use-context-selector' @@ -43,6 +43,15 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => { }, []) const [documentDetail, setDocumentDetail] = useState(null) + const currentPage = useMemo(() => { + return { + workspace_id: documentDetail?.data_source_info.notion_workspace_id, + page_id: documentDetail?.data_source_info.notion_page_id, + page_name: documentDetail?.name, + page_icon: documentDetail?.data_source_info.notion_page_icon, + type: documentDetail?.data_source_info.type, + } + }, [documentDetail]) useEffect(() => { (async () => { try { @@ -71,6 +80,8 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => { hasSetAPIKEY={hasSetAPIKEY} onSetting={showSetAPIKey} datasetId={datasetId} + dataSourceType={documentDetail.data_source_type} + notionPages={[currentPage]} indexingType={indexingTechnique || ''} isSetting documentDetail={documentDetail} diff --git a/web/app/components/datasets/documents/index.tsx b/web/app/components/datasets/documents/index.tsx index 66b77d1b4..4a4c21f21 100644 --- a/web/app/components/datasets/documents/index.tsx +++ b/web/app/components/datasets/documents/index.tsx @@ -4,7 +4,7 @@ import React, { useMemo, useState } from 'react' import useSWR from 'swr' import { useTranslation } from 'react-i18next' import { useRouter } from 'next/navigation' -import { debounce, omit } from 'lodash-es' +import { debounce, groupBy, omit } from 'lodash-es' // import Link from 'next/link' import { PlusIcon } from '@heroicons/react/24/solid' import List from './list' @@ -14,7 +14,12 @@ import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Pagination from '@/app/components/base/pagination' import { get } from '@/service/base' -import { fetchDocuments } from '@/service/datasets' +import { createDocument, fetchDocuments } from '@/service/datasets' +import { useDatasetDetailContext } from '@/context/dataset-detail' +import { NotionPageSelectorModal } from '@/app/components/base/notion-page-selector' +import type { DataSourceNotionPage } from '@/models/common' +import type { CreateDocumentReq } from '@/models/datasets' +import { DataSourceType } from '@/models/datasets' // Custom page count is not currently supported. const limit = 15 @@ -75,20 +80,63 @@ const Documents: FC = ({ datasetId }) => { const [searchValue, setSearchValue] = useState('') const [currPage, setCurrPage] = React.useState(0) const router = useRouter() + const { dataset } = useDatasetDetailContext() + const [notionPageSelectorModalVisible, setNotionPageSelectorModalVisible] = useState(false) + const [timerCanRun, setTimerCanRun] = useState(true) + const isDataSourceNotion = dataset?.data_source_type === DataSourceType.NOTION const query = useMemo(() => { - return { page: currPage + 1, limit, keyword: searchValue } - }, [searchValue, currPage]) + return { page: currPage + 1, limit, keyword: searchValue, fetch: isDataSourceNotion ? 
true : '' } + }, [searchValue, currPage, isDataSourceNotion]) - const { data: documentsRes, error, mutate } = useSWR({ - action: 'fetchDocuments', - datasetId, - params: query, - }, apiParams => fetchDocuments(omit(apiParams, 'action'))) + const { data: documentsRes, error, mutate } = useSWR( + { + action: 'fetchDocuments', + datasetId, + params: query, + }, + apiParams => fetchDocuments(omit(apiParams, 'action')), + { refreshInterval: (isDataSourceNotion && timerCanRun) ? 2500 : 0 }, + ) + const documentsWithProgress = useMemo(() => { + let completedNum = 0 + let percent = 0 + const documentsData = documentsRes?.data?.map((documentItem) => { + const { indexing_status, completed_segments, total_segments } = documentItem + const isEmbeddinged = indexing_status === 'completed' || indexing_status === 'paused' || indexing_status === 'error' + + if (isEmbeddinged) + completedNum++ + + const completedCount = completed_segments || 0 + const totalCount = total_segments || 0 + if (totalCount === 0 && completedCount === 0) { + percent = isEmbeddinged ? 100 : 0 + } + else { + const per = Math.round(completedCount * 100 / totalCount) + percent = per > 100 ? 100 : per + } + return { + ...documentItem, + percent, + } + }) + if (completedNum === documentsRes?.data?.length) + setTimerCanRun(false) + return { + ...documentsRes, + data: documentsData, + } + }, [documentsRes]) const total = documentsRes?.total || 0 const routeToDocCreate = () => { + if (isDataSourceNotion) { + setNotionPageSelectorModalVisible(true) + return + } router.push(`/datasets/${datasetId}/documents/create`) } @@ -96,6 +144,54 @@ const Documents: FC = ({ datasetId }) => { const isLoading = !documentsRes && !error + const handleSaveNotionPageSelected = async (selectedPages: (DataSourceNotionPage & { workspace_id: string })[]) => { + const workspacesMap = groupBy(selectedPages, 'workspace_id') + const workspaces = Object.keys(workspacesMap).map((workspaceId) => { + return { + workspaceId, + pages: workspacesMap[workspaceId], + } + }) + const params = { + data_source: { + type: dataset?.data_source_type, + info_list: { + data_source_type: dataset?.data_source_type, + notion_info_list: workspaces.map((workspace) => { + return { + workspace_id: workspace.workspaceId, + pages: workspace.pages.map((page) => { + const { page_id, page_name, page_icon, type } = page + return { + page_id, + page_name, + page_icon, + type, + } + }), + } + }), + }, + }, + indexing_technique: dataset?.indexing_technique, + process_rule: { + rules: {}, + mode: 'automatic', + }, + } as CreateDocumentReq + + await createDocument({ + datasetId, + body: params, + }) + mutate() + setTimerCanRun(true) + // mutateDatasetIndexingStatus(undefined, { revalidate: true }) + setNotionPageSelectorModalVisible(false) + } + + const documentsList = isDataSourceNotion ? documentsWithProgress?.data : documentsRes?.data + return (
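Editor's note: for Notion-backed datasets, the document list above keeps polling by handing SWR a conditional `refreshInterval` (2.5 s while anything is still embedding, 0 once done) and derives a clamped per-document percentage from segment counts. The percent math, isolated, with the zero-total edge reporting 100 for finished documents:

```ts
type IndexingProgress = {
  indexing_status: string
  completed_segments?: number
  total_segments?: number
}

const isFinished = (status: string) => ['completed', 'paused', 'error'].includes(status)

// completed/total segments -> integer percent, clamped to [0, 100].
// A finished document with no segment counts reports 100, not 0.
export const getPercent = (doc: IndexingProgress): number => {
  const completed = doc.completed_segments || 0
  const total = doc.total_segments || 0
  if (total === 0)
    return isFinished(doc.indexing_status) ? 100 : 0
  return Math.min(100, Math.round(completed * 100 / total))
}

// The list itself polls only while needed, e.g.:
// useSWR(key, fetcher, { refreshInterval: stillEmbedding ? 2500 : 0 })
```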
@@ -113,19 +209,29 @@ const Documents: FC = ({ datasetId }) => { />
{isLoading ? : total > 0 - ? + ? : } {/* Show Pagination only if the total is more than the limit */} {(total && total > limit) ? : null} + setNotionPageSelectorModalVisible(false)} + onSave={handleSaveNotionPageSelected} + datasetId={dataset?.id || ''} + />
) diff --git a/web/app/components/datasets/documents/list.tsx b/web/app/components/datasets/documents/list.tsx index 8cb7df253..db192a269 100644 --- a/web/app/components/datasets/documents/list.tsx +++ b/web/app/components/datasets/documents/list.tsx @@ -22,8 +22,10 @@ import type { IndicatorProps } from '@/app/components/header/indicator' import Indicator from '@/app/components/header/indicator' import { asyncRunSafe } from '@/utils' import { formatNumber } from '@/utils/format' -import { archiveDocument, deleteDocument, disableDocument, enableDocument } from '@/service/datasets' -import type { DocumentDisplayStatus, DocumentListResponse } from '@/models/datasets' +import { archiveDocument, deleteDocument, disableDocument, enableDocument, syncDocument } from '@/service/datasets' +import NotionIcon from '@/app/components/base/notion-icon' +import ProgressBar from '@/app/components/base/progress-bar' +import { DataSourceType, type DocumentDisplayStatus, type SimpleDocumentDetail } from '@/models/datasets' import type { CommonResponse } from '@/models/common' export const SettingsIcon: FC<{ className?: string }> = ({ className }) => { @@ -32,6 +34,12 @@ export const SettingsIcon: FC<{ className?: string }> = ({ className }) => { } +export const SyncIcon: FC<{ className?: string }> = () => { + return + + +} + export const FilePlusIcon: FC<{ className?: string }> = ({ className }) => { return @@ -77,7 +85,7 @@ export const StatusItem: FC<{
} -type OperationName = 'delete' | 'archive' | 'enable' | 'disable' +type OperationName = 'delete' | 'archive' | 'enable' | 'disable' | 'sync' // operation action for list and detail export const OperationAction: FC<{ @@ -85,13 +93,14 @@ export const OperationAction: FC<{ enabled: boolean archived: boolean id: string + data_source_type: string } datasetId: string - onUpdate: () => void + onUpdate: (operationName?: string) => void scene?: 'list' | 'detail' className?: string }> = ({ datasetId, detail, onUpdate, scene = 'list', className = '' }) => { - const { id, enabled = false, archived = false } = detail || {} + const { id, enabled = false, archived = false, data_source_type } = detail || {} const [showModal, setShowModal] = useState(false) const { notify } = useContext(ToastContext) const { t } = useTranslation() @@ -111,6 +120,9 @@ export const OperationAction: FC<{ case 'disable': opApi = disableDocument break + case 'sync': + opApi = syncDocument + break default: opApi = deleteDocument break @@ -120,7 +132,7 @@ export const OperationAction: FC<{ notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) else notify({ type: 'error', message: t('common.actionMsg.modificationFailed') }) - onUpdate() + onUpdate(operationName) } return
{t('datasetDocuments.list.action.settings')}
- {/*
router.push(`/datasets/${datasetId}/documents/create`)}> - - {t('datasetDocuments.list.action.uploadFile')} -
*/} + { + data_source_type === 'notion_import' && ( +
onOperate('sync')}> + + {t('datasetDocuments.list.action.sync')} +
+ ) + } )} @@ -236,8 +252,9 @@ const renderCount = (count: number | undefined) => { return `${formatNumber((count / 1000).toFixed(1))}k` } +type LocalDoc = SimpleDocumentDetail & { percent?: number } type IDocumentListProps = { - documents: DocumentListResponse['data'] + documents: LocalDoc[] datasetId: string onUpdate: () => void } @@ -248,7 +265,7 @@ type IDocumentListProps = { const DocumentList: FC = ({ documents = [], datasetId, onUpdate }) => { const { t } = useTranslation() const router = useRouter() - const [localDocs, setLocalDocs] = useState(documents) + const [localDocs, setLocalDocs] = useState(documents) const [enableSort, setEnableSort] = useState(false) useEffect(() => { @@ -296,8 +313,16 @@ const DocumentList: FC = ({ documents = [], datasetId, onUpd }}> {doc.position} -
- {doc?.name?.replace(/\.[^/.]+$/, '')}.{suffix} + { + doc?.data_source_type === DataSourceType.NOTION + ? + :
+ } + { + doc.data_source_type === DataSourceType.NOTION + ? {doc.name} + : {doc?.name?.replace(/\.[^/.]+$/, '')}.{suffix} + } {renderCount(doc.word_count)} {renderCount(doc.hit_count)} @@ -305,12 +330,16 @@ const DocumentList: FC = ({ documents = [], datasetId, onUpd {dayjs.unix(doc.created_at).format(t('datasetHitTesting.dateTimeFormat') as string)} - + { + (['indexing', 'splitting', 'parsing', 'cleaning'].includes(doc.indexing_status) && doc?.data_source_type === DataSourceType.NOTION) + ? + : + } diff --git a/web/app/components/datasets/documents/style.module.css b/web/app/components/datasets/documents/style.module.css index 55b2705b1..df8f09c74 100644 --- a/web/app/components/datasets/documents/style.module.css +++ b/web/app/components/datasets/documents/style.module.css @@ -75,6 +75,9 @@ .markdownIcon { background-image: url(./assets/md.svg); } +.mdIcon { + background-image: url(./assets/md.svg); +} .xlsIcon { background-image: url(./assets/xlsx.svg); } diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index b348ca1ec..b1dd6382c 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -9,9 +9,9 @@ import { Menu, Transition } from '@headlessui/react' import Indicator from '../indicator' import AccountSetting from '../account-setting' import AccountAbout from '../account-about' +import WorkplaceSelector from './workplace-selector' import type { LangGeniusVersionResponse, UserProfileResponse } from '@/models/common' import I18n from '@/context/i18n' -import WorkplaceSelector from './workplace-selector' import Avatar from '@/app/components/base/avatar' type IAppSelectorProps = { diff --git a/web/app/components/header/account-setting/data-source-page/data-source-notion/index.tsx b/web/app/components/header/account-setting/data-source-page/data-source-notion/index.tsx new file mode 100644 index 000000000..276df6a15 --- /dev/null +++ b/web/app/components/header/account-setting/data-source-page/data-source-notion/index.tsx @@ -0,0 +1,102 @@ +import { useTranslation } from 'react-i18next' +import Link from 'next/link' +import { PlusIcon } from '@heroicons/react/24/solid' +import cn from 'classnames' +import Indicator from '../../../indicator' +import Operate from './operate' +import s from './style.module.css' +import NotionIcon from '@/app/components/base/notion-icon' +import { apiPrefix } from '@/config' +import type { DataSourceNotion as TDataSourceNotion } from '@/models/common' + +type DataSourceNotionProps = { + workspaces: TDataSourceNotion[] +} +const DataSourceNotion = ({ + workspaces, +}: DataSourceNotionProps) => { + const { t } = useTranslation() + const connected = !!workspaces.length + + return ( +
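Editor's note: `OperationAction` earlier in this list.tsx diff maps an operation name onto the matching service call before toasting; the new `sync` case slots into that switch. The same dispatch sketched as a lookup table; the service names come from the patch's imports, and their uniform signature is an assumption:

```ts
type CommonResponse = { result: string }
type OpArgs = { datasetId: string; documentId: string }
type OperationName = 'delete' | 'archive' | 'enable' | 'disable' | 'sync'

declare const deleteDocument: (args: OpArgs) => Promise<CommonResponse>
declare const archiveDocument: (args: OpArgs) => Promise<CommonResponse>
declare const enableDocument: (args: OpArgs) => Promise<CommonResponse>
declare const disableDocument: (args: OpArgs) => Promise<CommonResponse>
declare const syncDocument: (args: OpArgs) => Promise<CommonResponse>

// One entry per menu action; adding an operation is one line here
// plus its menu item, rather than another switch branch.
const opApiMap: Record<OperationName, (args: OpArgs) => Promise<CommonResponse>> = {
  delete: deleteDocument,
  archive: archiveDocument,
  enable: enableDocument,
  disable: disableDocument,
  sync: syncDocument,
}
```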
+
+
+
+
+ {t('common.dataSource.notion.title')} +
+ { + !connected && ( +
+ {t('common.dataSource.notion.description')} +
+ ) + } +
+ { + !connected + ? ( + + {t('common.dataSource.connect')} + + ) + : ( + + + {t('common.dataSource.notion.addWorkspace')} + + ) + } +
+ { + connected && ( +
+
+ {t('common.dataSource.notion.connectedWorkspace')} +
+
+
+ ) + } + { + connected && ( +
+ { + workspaces.map(workspace => ( +
+ +
{workspace.source_info.workspace_name}
+ { + workspace.is_bound + ? + : + } +
+ { + workspace.is_bound + ? t('common.dataSource.notion.connected') + : t('common.dataSource.notion.disconnected') + } +
+
+ +
+ )) + } +
+ ) + } +
+ ) +} + +export default DataSourceNotion diff --git a/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.module.css b/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.module.css new file mode 100644 index 000000000..60f89a720 --- /dev/null +++ b/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.module.css @@ -0,0 +1,14 @@ +.file-icon { + background: url(../../../../assets/file.svg) center center no-repeat; + background-size: contain; +} + +.sync-icon { + background: url(../../../../assets/sync.svg) center center no-repeat; + background-size: contain; +} + +.trash-icon { + background: url(../../../../assets/trash.svg) center center no-repeat; + background-size: contain; +} \ No newline at end of file diff --git a/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.tsx b/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.tsx new file mode 100644 index 000000000..cd96f379f --- /dev/null +++ b/web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.tsx @@ -0,0 +1,107 @@ +'use client' +import { useTranslation } from 'react-i18next' +import { Fragment } from 'react' +import Link from 'next/link' +import { useSWRConfig } from 'swr' +import { EllipsisHorizontalIcon } from '@heroicons/react/24/solid' +import { Menu, Transition } from '@headlessui/react' +import cn from 'classnames' +import s from './index.module.css' +import { apiPrefix } from '@/config' +import { syncDataSourceNotion, updateDataSourceNotionAction } from '@/service/common' +import Toast from '@/app/components/base/toast' +import type { DataSourceNotion } from '@/models/common' + +type OperateProps = { + workspace: DataSourceNotion +} +export default function Operate({ + workspace, +}: OperateProps) { + const itemClassName = ` + flex px-3 py-2 hover:bg-gray-50 text-sm text-gray-700 + cursor-pointer + ` + const itemIconClassName = ` + mr-2 mt-[2px] w-4 h-4 + ` + const { t } = useTranslation() + const { mutate } = useSWRConfig() + + const updateIntegrates = () => { + Toast.notify({ + type: 'success', + message: t('common.api.success'), + }) + mutate({ url: 'data-source/integrates' }) + } + const handleSync = async () => { + await syncDataSourceNotion({ url: `/oauth/data-source/notion/${workspace.id}/sync` }) + updateIntegrates() + } + const handleRemove = async () => { + await updateDataSourceNotionAction({ url: `/data-source/integrates/${workspace.id}/disable` }) + updateIntegrates() + } + + return ( + + { + ({ open }) => ( + <> + + + + + +
+ + +
+
+
{t('common.dataSource.notion.changeAuthorizedPages')}
+
+ {workspace.source_info.total} {t('common.dataSource.notion.pagesAuthorized')} +
+
+ +
+ +
+
+
{t('common.dataSource.notion.sync')}
+
+ +
+ +
+
+
+
{t('common.dataSource.notion.remove')}
+
+
+ + + + + ) + } +
+ ) +} diff --git a/web/app/components/header/account-setting/data-source-page/data-source-notion/style.module.css b/web/app/components/header/account-setting/data-source-page/data-source-notion/style.module.css new file mode 100644 index 000000000..ede323072 --- /dev/null +++ b/web/app/components/header/account-setting/data-source-page/data-source-notion/style.module.css @@ -0,0 +1,12 @@ +.notion-icon { + background: #ffffff url(../../../assets/notion.svg) center center no-repeat; + background-size: 20px 20px; +} + +.workspace-item { + box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); +} + +.workspace-item:last-of-type { + margin-bottom: 0; +} \ No newline at end of file diff --git a/web/app/components/header/account-setting/data-source-page/index.module.css b/web/app/components/header/account-setting/data-source-page/index.module.css new file mode 100644 index 000000000..e69de29bb diff --git a/web/app/components/header/account-setting/data-source-page/index.tsx b/web/app/components/header/account-setting/data-source-page/index.tsx new file mode 100644 index 000000000..761d9cbfe --- /dev/null +++ b/web/app/components/header/account-setting/data-source-page/index.tsx @@ -0,0 +1,17 @@ +import useSWR from 'swr' +import { useTranslation } from 'react-i18next' +import DataSourceNotion from './data-source-notion' +import { fetchDataSource } from '@/service/common' + +export default function DataSourcePage() { + const { t } = useTranslation() + const { data } = useSWR({ url: 'data-source/integrates' }, fetchDataSource) + const notionWorkspaces = data?.data.filter(item => item.provider === 'notion') || [] + + return ( +
+
{t('common.dataSource.add')}
+ +
+ ) +} diff --git a/web/app/components/header/account-setting/index.module.css b/web/app/components/header/account-setting/index.module.css index e7d1ef61e..f97483674 100644 --- a/web/app/components/header/account-setting/index.module.css +++ b/web/app/components/header/account-setting/index.module.css @@ -2,4 +2,14 @@ max-width: 720px !important; padding: 0 !important; overflow-y: auto; +} + +.data-source-icon { + background: url(../assets/data-source.svg) center center no-repeat; + background-size: cover; +} + +.data-source-solid-icon { + background: url(../assets/data-source-blue.svg) center center no-repeat; + background-size: cover; } \ No newline at end of file diff --git a/web/app/components/header/account-setting/index.tsx b/web/app/components/header/account-setting/index.tsx index 7689c24cb..84641b92c 100644 --- a/web/app/components/header/account-setting/index.tsx +++ b/web/app/components/header/account-setting/index.tsx @@ -1,20 +1,32 @@ 'use client' import { useTranslation } from 'react-i18next' import { useState } from 'react' -import { AtSymbolIcon, GlobeAltIcon, UserIcon, XMarkIcon, CubeTransparentIcon, UsersIcon } from '@heroicons/react/24/outline' +import { AtSymbolIcon, CubeTransparentIcon, GlobeAltIcon, UserIcon, UsersIcon, XMarkIcon } from '@heroicons/react/24/outline' import { GlobeAltIcon as GlobalAltIconSolid, UserIcon as UserIconSolid, UsersIcon as UsersIconSolid } from '@heroicons/react/24/solid' +import cn from 'classnames' import AccountPage from './account-page' import MembersPage from './members-page' import IntegrationsPage from './Integrations-page' import LanguagePage from './language-page' import ProviderPage from './provider-page' +import DataSourcePage from './data-source-page' import s from './index.module.css' import Modal from '@/app/components/base/modal' const iconClassName = ` - w-[18px] h-[18px] ml-3 mr-2 + w-4 h-4 ml-3 mr-2 ` +type IconProps = { + className?: string +} +const DataSourceIcon = ({ className }: IconProps) => ( +
+) +const DataSourceSolidIcon = ({ className }: IconProps) => ( +
+) + type IAccountSettingProps = { onCancel: () => void activeTab?: string @@ -48,7 +60,7 @@ export default function AccountSetting({ icon: , activeIcon: , }, - ] + ], }, { key: 'workspace-group', @@ -66,8 +78,14 @@ export default function AccountSetting({ icon: , activeIcon: , }, - ] - } + { + key: 'data-source', + name: t('common.settings.dataSource'), + icon: , + activeIcon: , + }, + ], + }, ] return ( @@ -126,6 +144,9 @@ export default function AccountSetting({ { activeMenu === 'provider' && } + { + activeMenu === 'data-source' && + }
diff --git a/web/app/components/header/account-setting/members-page/index.tsx b/web/app/components/header/account-setting/members-page/index.tsx index e78328972..4242b7415 100644 --- a/web/app/components/header/account-setting/members-page/index.tsx +++ b/web/app/components/header/account-setting/members-page/index.tsx @@ -1,19 +1,19 @@ 'use client' import { useState } from 'react' -import s from './index.module.css' import cn from 'classnames' import useSWR from 'swr' import dayjs from 'dayjs' import 'dayjs/locale/zh-cn' import relativeTime from 'dayjs/plugin/relativeTime' -import I18n from '@/context/i18n' import { useContext } from 'use-context-selector' -import { fetchMembers } from '@/service/common' import { UserPlusIcon } from '@heroicons/react/24/outline' import { useTranslation } from 'react-i18next' +import s from './index.module.css' import InviteModal from './invite-modal' import InvitedModal from './invited-modal' import Operation from './operation' +import { fetchMembers } from '@/service/common' +import I18n from '@/context/i18n' import { useAppContext } from '@/context/app-context' import Avatar from '@/app/components/base/avatar' import { useWorkspacesContext } from '@/context/workspace-context' @@ -35,18 +35,18 @@ const MembersPage = () => { const owner = accounts.filter(account => account.role === 'owner')?.[0]?.email === userProfile.email const { workspaces } = useWorkspacesContext() const currentWrokspace = workspaces.filter(item => item.current)?.[0] - + return ( <>
-
{currentWrokspace.name}
+
{currentWrokspace?.name}
{t('common.userProfile.workspace')}
setInviteModalVisible(true)}> @@ -78,10 +78,10 @@ const MembersPage = () => {
{dayjs(Number((account.last_login_at || account.created_at)) * 1000).locale(locale === 'zh-Hans' ? 'zh-cn' : 'en').fromNow()}
{ - owner && account.role !== 'owner' + (owner && account.role !== 'owner') ? mutate()} /> :
{RoleMap[account.role] || RoleMap.normal}
- } + }
)) @@ -111,4 +111,4 @@ const MembersPage = () => { ) } -export default MembersPage \ No newline at end of file +export default MembersPage diff --git a/web/app/components/header/assets/data-source-blue.svg b/web/app/components/header/assets/data-source-blue.svg new file mode 100644 index 000000000..461b2c0b1 --- /dev/null +++ b/web/app/components/header/assets/data-source-blue.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/components/header/assets/data-source.svg b/web/app/components/header/assets/data-source.svg new file mode 100644 index 000000000..a2c11b4d4 --- /dev/null +++ b/web/app/components/header/assets/data-source.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/components/header/assets/file.svg b/web/app/components/header/assets/file.svg new file mode 100644 index 000000000..aab093a06 --- /dev/null +++ b/web/app/components/header/assets/file.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/components/header/assets/notion.svg b/web/app/components/header/assets/notion.svg new file mode 100644 index 000000000..eeda89492 --- /dev/null +++ b/web/app/components/header/assets/notion.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + diff --git a/web/app/components/header/assets/sync.svg b/web/app/components/header/assets/sync.svg new file mode 100644 index 000000000..795077a41 --- /dev/null +++ b/web/app/components/header/assets/sync.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/components/header/assets/trash.svg b/web/app/components/header/assets/trash.svg new file mode 100644 index 000000000..00c29894f --- /dev/null +++ b/web/app/components/header/assets/trash.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/components/header/nav/nav-selector/index.tsx b/web/app/components/header/nav/nav-selector/index.tsx index 27e0931a5..15e576228 100644 --- a/web/app/components/header/nav/nav-selector/index.tsx +++ b/web/app/components/header/nav/nav-selector/index.tsx @@ -24,7 +24,7 @@ export type INavSelectorProps = { const itemClassName = ` flex items-center w-full h-10 px-3 text-gray-700 text-[14px] - rounded-lg font-normal hover:bg-gray-100 cursor-pointer + rounded-lg font-normal hover:bg-gray-100 cursor-pointer truncate ` const NavSelector = ({ curNav, navs, createText, onCreate, onLoadmore }: INavSelectorProps) => { @@ -50,9 +50,9 @@ const NavSelector = ({ curNav, navs, createText, onCreate, onLoadmore }: INavSel text-[#1C64F2] hover:bg-[#EBF5FF] " > - {curNav?.name} +
{curNav?.name}
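
A note on the document-list hunk: the name cell now branches on data_source_type. A minimal TypeScript sketch of that rule, under stand-in types (DocRow here approximates the PR's SimpleDocumentDetail and models only the two fields the rule reads):

    // Notion pages keep their title verbatim; uploaded files are shown as
    // "<basename>.<suffix>", i.e. the last extension is stripped and the
    // canonical suffix re-appended.
    type DocRow = {
      name: string
      data_source_type: 'notion' | 'upload_file'
    }

    const displayName = (doc: DocRow, suffix: string): string =>
      doc.data_source_type === 'notion'
        ? doc.name
        : `${doc.name.replace(/\.[^/.]+$/, '')}.${suffix}`

For example, displayName({ name: 'spec.v2.md', data_source_type: 'upload_file' }, 'md') returns 'spec.v2.md', while a Notion page titled 'Roadmap' passes through untouched.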
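
The status cell in the same table follows a companion rule: a Notion document that is still being processed renders the progress bar (fed by the optional percent field added on LocalDoc), while everything else keeps the status badge. A sketch, assuming DataSourceType.NOTION serializes to 'notion' as in web/models/datasets:

    // The four in-flight states come straight from the list.tsx hunk.
    const RUNNING_STATES = ['indexing', 'splitting', 'parsing', 'cleaning']

    const showsProgressBar = (doc: { indexing_status: string; data_source_type?: string }): boolean =>
      RUNNING_STATES.includes(doc.indexing_status)
      && doc.data_source_type === 'notion'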
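
In the Operate menu, sync and disconnect share one shape: call the endpoint, toast success, then revalidate the data-source/integrates SWR key so every consumer (including DataSourcePage) re-renders; the third item, changing authorized pages, is a plain Link back through the OAuth flow. A sketch of the shared pattern; request is a placeholder for the project's service helpers (syncDataSourceNotion / updateDataSourceNotionAction), whose HTTP verbs are not visible in this diff:

    import { useSWRConfig } from 'swr'

    // Placeholder for the project's fetch wrapper.
    declare function request(url: string): Promise<unknown>

    export const useNotionWorkspaceActions = (workspaceId: string) => {
      const { mutate } = useSWRConfig()

      // Revalidating the shared key refreshes the workspace list everywhere
      // it is consumed via useSWR({ url: 'data-source/integrates' }, ...).
      const refresh = () => mutate({ url: 'data-source/integrates' })

      const sync = async () => {
        await request(`/oauth/data-source/notion/${workspaceId}/sync`)
        await refresh()
      }

      const disconnect = async () => {
        await request(`/data-source/integrates/${workspaceId}/disable`)
        await refresh()
      }

      return { sync, disconnect }
    }

Keying the cache on the same object shape the fetcher receives is what lets a single mutate call invalidate every consumer of that endpoint.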
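
DataSourcePage narrows the integrates response to Notion entries before rendering. A sketch of that selection with only the fields DataSourceNotion actually reads; the full DataSourceNotion model lives in web/models/common, so this shape is an assumption:

    type DataSourceIntegrate = {
      id: string
      provider: string // 'notion' for the workspaces rendered here
      is_bound: boolean
      source_info: { workspace_name: string; total: number }
    }

    const notionWorkspaces = (data?: { data: DataSourceIntegrate[] }) =>
      data?.data.filter(item => item.provider === 'notion') || []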