diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index acb226530..8c429044d 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -1,6 +1,5 @@
 import uuid
 
-import pandas as pd
 from flask import request
 from flask_login import current_user
 from flask_restful import Resource, marshal, reqparse
@@ -14,8 +13,6 @@ from controllers.console.datasets.error import (
     ChildChunkDeleteIndexError,
     ChildChunkIndexingError,
     InvalidActionError,
-    NoFileUploadedError,
-    TooManyFilesError,
 )
 from controllers.console.wraps import (
     account_initialization_required,
@@ -32,6 +29,7 @@ from extensions.ext_redis import redis_client
 from fields.segment_fields import child_chunk_fields, segment_fields
 from libs.login import login_required
 from models.dataset import ChildChunk, DocumentSegment
+from models.model import UploadFile
 from services.dataset_service import DatasetService, DocumentService, SegmentService
 from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
 from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
@@ -365,37 +363,28 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound("Document not found.")
-        # get file from request
-        file = request.files["file"]
-        # check file
-        if "file" not in request.files:
-            raise NoFileUploadedError()
-        if len(request.files) > 1:
-            raise TooManyFilesError()
+        parser = reqparse.RequestParser()
+        parser.add_argument("upload_file_id", type=str, required=True, nullable=False, location="json")
+        args = parser.parse_args()
+        upload_file_id = args["upload_file_id"]
+
+        upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
+        if not upload_file:
+            raise NotFound("UploadFile not found.")
+
         # check file type
-        if not file.filename or not file.filename.lower().endswith(".csv"):
+        if not upload_file.name or not upload_file.name.lower().endswith(".csv"):
             raise ValueError("Invalid file type. Only CSV files are allowed")
 
         try:
-            # Skip the first row
-            df = pd.read_csv(file)
-            result = []
-            for index, row in df.iterrows():
-                if document.doc_form == "qa_model":
-                    data = {"content": row.iloc[0], "answer": row.iloc[1]}
-                else:
-                    data = {"content": row.iloc[0]}
-                result.append(data)
-            if len(result) == 0:
-                raise ValueError("The CSV file is empty.")
             # async job
             job_id = str(uuid.uuid4())
             indexing_cache_key = f"segment_batch_import_{str(job_id)}"
             # send batch add segments task
             redis_client.setnx(indexing_cache_key, "waiting")
             batch_create_segment_to_index_task.delay(
-                str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id
+                str(job_id), upload_file_id, dataset_id, document_id, current_user.current_tenant_id, current_user.id
             )
         except Exception as e:
             return {"error": str(e)}, 500
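The controller now expects a JSON body carrying `upload_file_id` instead of a multipart form, so importing segments becomes a two-step flow: upload the CSV once through the shared file-upload endpoint, then start the batch import by id. A minimal client sketch of that contract, assuming a local deployment — the base URL, auth header, and the exact shape of the upload response are assumptions, not part of this patch:

```python
import requests

BASE = "http://localhost:5001/console/api"  # hypothetical host for a local deployment
HEADERS = {"Authorization": "Bearer <console-api-token>"}  # hypothetical token

# Step 1: upload the CSV once; the frontend change below hits the same
# endpoint with the '?source=datasets' query string.
with open("segments.csv", "rb") as f:
    uploaded = requests.post(
        f"{BASE}/files/upload?source=datasets",
        headers=HEADERS,
        files={"file": ("segments.csv", f, "text/csv")},
    ).json()

# Step 2: start the batch import by id; no file bytes are re-sent.
resp = requests.post(
    f"{BASE}/datasets/<dataset_id>/documents/<document_id>/segments/batch_import",
    headers=HEADERS,
    json={"upload_file_id": uploaded["id"]},
)
print(resp.json())  # the frontend reads the returned job_status to track the async job
```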
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index d72e35029..714e30acc 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -1,9 +1,12 @@
 import datetime
 import logging
+import tempfile
 import time
 import uuid
+from pathlib import Path
 
 import click
+import pandas as pd
 from celery import shared_task  # type: ignore
 from sqlalchemy import func
 from sqlalchemy.orm import Session
@@ -12,15 +15,17 @@ from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
+from extensions.ext_storage import storage
 from libs import helper
 from models.dataset import Dataset, Document, DocumentSegment
+from models.model import UploadFile
 from services.vector_service import VectorService
 
 
 @shared_task(queue="dataset")
 def batch_create_segment_to_index_task(
     job_id: str,
-    content: list,
+    upload_file_id: str,
     dataset_id: str,
     document_id: str,
     tenant_id: str,
@@ -29,13 +34,13 @@ def batch_create_segment_to_index_task(
     """
     Async batch create segment to index
     :param job_id:
-    :param content:
+    :param upload_file_id:
     :param dataset_id:
     :param document_id:
     :param tenant_id:
     :param user_id:
 
-    Usage: batch_create_segment_to_index_task.delay(job_id, content, dataset_id, document_id, tenant_id, user_id)
+    Usage: batch_create_segment_to_index_task.delay(job_id, upload_file_id, dataset_id, document_id, tenant_id, user_id)
     """
     logging.info(click.style(f"Start batch create segment jobId: {job_id}", fg="green"))
     start_at = time.perf_counter()
@@ -58,6 +63,29 @@ def batch_create_segment_to_index_task(
             or dataset_document.indexing_status != "completed"
         ):
             raise ValueError("Document is not available.")
+
+        upload_file = session.get(UploadFile, upload_file_id)
+        if not upload_file:
+            raise ValueError("UploadFile not found.")
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            suffix = Path(upload_file.key).suffix
+            # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
+            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"  # type: ignore
+            storage.download(upload_file.key, file_path)
+
+            # Skip the first row
+            df = pd.read_csv(file_path)
+            content = []
+            for index, row in df.iterrows():
+                if dataset_document.doc_form == "qa_model":
+                    data = {"content": row.iloc[0], "answer": row.iloc[1]}
+                else:
+                    data = {"content": row.iloc[0]}
+                content.append(data)
+            if len(content) == 0:
+                raise ValueError("The CSV file is empty.")
+
         document_segments = []
         embedding_model = None
         if dataset.indexing_technique == "high_quality":
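The task now downloads the upload to a scratch file but names it with the private helper `tempfile._get_candidate_names`, which the diff itself flags with a FIXME and a `# type: ignore`. A sketch of one public-API alternative — the `scratch_path` helper is illustrative, not part of the patch: since the temporary directory is already unique and short-lived, any unique basename does the job.

```python
import tempfile
import uuid
from pathlib import Path

def scratch_path(temp_dir: str, key: str) -> str:
    # Derive the suffix from the storage key, as the task does, but name the
    # scratch file with uuid4 instead of tempfile._get_candidate_names().
    suffix = Path(key).suffix
    return str(Path(temp_dir) / f"{uuid.uuid4().hex}{suffix}")

with tempfile.TemporaryDirectory() as temp_dir:
    file_path = scratch_path(temp_dir, "upload_files/tenant/abc.csv")  # key is illustrative
    # storage.download(upload_file.key, file_path) would then write the CSV here.
```

`tempfile.NamedTemporaryFile(suffix=suffix, delete=False)` inside the same directory would work equally well and also avoids the mypy suppression.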
diff --git a/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx b/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx
index c2224296d..c352f11d7 100644
--- a/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx
+++ b/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx
@@ -1,6 +1,6 @@
 'use client'
 import type { FC } from 'react'
-import React, { useEffect, useRef, useState } from 'react'
+import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react'
 import {
   RiDeleteBinLine,
 } from '@remixicon/react'
@@ -10,10 +10,17 @@ import cn from '@/utils/classnames'
 import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files'
 import { ToastContext } from '@/app/components/base/toast'
 import Button from '@/app/components/base/button'
+import type { FileItem } from '@/models/datasets'
+import { upload } from '@/service/base'
+import useSWR from 'swr'
+import { fetchFileUploadConfig } from '@/service/common'
+import SimplePieChart from '@/app/components/base/simple-pie-chart'
+import { Theme } from '@/types/app'
+import useTheme from '@/hooks/use-theme'
 
 export type Props = {
-  file: File | undefined
-  updateFile: (file?: File) => void
+  file: FileItem | undefined
+  updateFile: (file?: FileItem) => void
 }
 
 const CSVUploader: FC<Props> = ({
@@ -26,6 +33,68 @@ const CSVUploader: FC<Props> = ({
   const dropRef = useRef<HTMLDivElement>(null)
   const dragRef = useRef<HTMLDivElement>(null)
   const fileUploader = useRef<HTMLInputElement>(null)
+  const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig)
+  const fileUploadConfig = useMemo(() => fileUploadConfigResponse ?? {
+    file_size_limit: 15,
+  }, [fileUploadConfigResponse])
+
+  const fileUpload = useCallback(async (fileItem: FileItem): Promise<FileItem> => {
+    fileItem.progress = 0
+
+    const formData = new FormData()
+    formData.append('file', fileItem.file)
+    const onProgress = (e: ProgressEvent) => {
+      if (e.lengthComputable) {
+        const progress = Math.floor(e.loaded / e.total * 100)
+        updateFile({
+          ...fileItem,
+          progress,
+        })
+      }
+    }
+
+    return upload({
+      xhr: new XMLHttpRequest(),
+      data: formData,
+      onprogress: onProgress,
+    }, false, undefined, '?source=datasets')
+      .then((res: File) => {
+        const completeFile = {
+          fileID: fileItem.fileID,
+          file: res,
+          progress: 100,
+        }
+        updateFile(completeFile)
+        return Promise.resolve({ ...completeFile })
+      })
+      .catch((e) => {
+        notify({ type: 'error', message: e?.response?.code === 'forbidden' ? e?.response?.message : t('datasetCreation.stepOne.uploader.failed') })
+        const errorFile = {
+          ...fileItem,
+          progress: -2,
+        }
+        updateFile(errorFile)
+        return Promise.resolve({ ...errorFile })
+      })
+      .finally()
+  }, [notify, t, updateFile])
+
+  const uploadFile = useCallback(async (fileItem: FileItem) => {
+    await fileUpload(fileItem)
+  }, [fileUpload])
+
+  const initialUpload = useCallback((file?: File) => {
+    if (!file)
+      return false
+
+    const newFile: FileItem = {
+      fileID: `file0-${Date.now()}`,
+      file,
+      progress: -1,
+    }
+    updateFile(newFile)
+    uploadFile(newFile)
+  }, [updateFile, uploadFile])
 
   const handleDragEnter = (e: DragEvent) => {
     e.preventDefault()
@@ -52,7 +121,7 @@ const CSVUploader: FC<Props> = ({
       notify({ type: 'error', message: t('datasetCreation.stepOne.uploader.validation.count') })
       return
     }
-    updateFile(files[0])
+    initialUpload(files[0])
   }
   const selectHandle = () => {
     if (fileUploader.current)
@@ -63,11 +132,43 @@ const CSVUploader: FC<Props> = ({
       fileUploader.current.value = ''
     updateFile()
   }
+
+  const getFileType = (currentFile: File) => {
+    if (!currentFile)
+      return ''
+
+    const arr = currentFile.name.split('.')
+    return arr[arr.length - 1]
+  }
+
+  const isValid = useCallback((file?: File) => {
+    if (!file)
+      return false
+
+    const { size } = file
+    const ext = `.${getFileType(file)}`
+    const isValidType = ext.toLowerCase() === '.csv'
+    if (!isValidType)
+      notify({ type: 'error', message: t('datasetCreation.stepOne.uploader.validation.typeError') })
+
+    const isValidSize = size <= fileUploadConfig.file_size_limit * 1024 * 1024
+    if (!isValidSize)
+      notify({ type: 'error', message: t('datasetCreation.stepOne.uploader.validation.size', { size: fileUploadConfig.file_size_limit }) })
+
+    return isValidType && isValidSize
+  }, [fileUploadConfig, notify, t])
+
   const fileChangeHandle = (e: React.ChangeEvent<HTMLInputElement>) => {
     const currentFile = e.target.files?.[0]
-    updateFile(currentFile)
+    if (!isValid(currentFile))
+      return
+
+    initialUpload(currentFile)
   }
 
+  const { theme } = useTheme()
+  const chartColor = useMemo(() => theme === Theme.dark ? '#5289ff' : '#296dff', [theme])
+
   useEffect(() => {
     dropRef.current?.addEventListener('dragenter', handleDragEnter)
     dropRef.current?.addEventListener('dragover', handleDragOver)
@@ -108,10 +209,16 @@ const CSVUploader: FC<Props> = ({
-            {file.name.replace(/.csv$/, '')}
+            {file.file.name.replace(/.csv$/, '')}
             .csv
+          {(file.progress < 100 && file.progress >= 0) && (
+            <>
+              <SimplePieChart percentage={file.progress} stroke={chartColor} fill={chartColor} />
+            </>
+          )}
diff --git a/web/app/components/datasets/documents/detail/batch-modal/index.tsx b/web/app/components/datasets/documents/detail/batch-modal/index.tsx
index 614471c56..0952a823b 100644
--- a/web/app/components/datasets/documents/detail/batch-modal/index.tsx
+++ b/web/app/components/datasets/documents/detail/batch-modal/index.tsx
@@ -7,14 +7,14 @@ import CSVUploader from './csv-uploader'
 import CSVDownloader from './csv-downloader'
 import Button from '@/app/components/base/button'
 import Modal from '@/app/components/base/modal'
-import type { ChunkingMode } from '@/models/datasets'
+import type { ChunkingMode, FileItem } from '@/models/datasets'
 import { noop } from 'lodash-es'
 
 export type IBatchModalProps = {
   isShow: boolean
   docForm: ChunkingMode
   onCancel: () => void
-  onConfirm: (file: File) => void
+  onConfirm: (file: FileItem) => void
 }
 
 const BatchModal: FC<IBatchModalProps> = ({
@@ -24,8 +24,8 @@ const BatchModal: FC<IBatchModalProps> = ({
   isShow,
   docForm,
   onCancel,
   onConfirm,
 }) => {
   const { t } = useTranslation()
-  const [currentCSV, setCurrentCSV] = useState<File>()
-  const handleFile = (file?: File) => setCurrentCSV(file)
+  const [currentCSV, setCurrentCSV] = useState<FileItem>()
+  const handleFile = (file?: FileItem) => setCurrentCSV(file)
 
   const handleSend = () => {
     if (!currentCSV)
@@ -56,7 +56,7 @@ const BatchModal: FC<IBatchModalProps> = ({
-
diff --git a/web/app/components/datasets/documents/detail/index.tsx b/web/app/components/datasets/documents/detail/index.tsx
index aff74038e..79d12e47e 100644
--- a/web/app/components/datasets/documents/detail/index.tsx
+++ b/web/app/components/datasets/documents/detail/index.tsx
@@ -17,7 +17,7 @@ import cn from '@/utils/classnames'
 import Divider from '@/app/components/base/divider'
 import Loading from '@/app/components/base/loading'
 import { ToastContext } from '@/app/components/base/toast'
-import type { ChunkingMode, ParentMode, ProcessMode } from '@/models/datasets'
+import type { ChunkingMode, FileItem, ParentMode, ProcessMode } from '@/models/datasets'
 import { useDatasetDetailContext } from '@/context/dataset-detail'
 import FloatRightContainer from '@/app/components/base/float-right-container'
 import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
@@ -111,12 +111,10 @@ const DocumentDetail: FC<Props> = ({ datasetId, documentId }) => {
   }
 
   const { mutateAsync: segmentBatchImport } = useSegmentBatchImport()
-  const runBatch = async (csv: File) => {
-    const formData = new FormData()
-    formData.append('file', csv)
+  const runBatch = async (csv: FileItem) => {
     await segmentBatchImport({
       url: `/datasets/${datasetId}/documents/${documentId}/segments/batch_import`,
-      body: formData,
+      body: { upload_file_id: csv.file.id! },
     }, {
       onSuccess: (res) => {
        setImportStatus(res.job_status)
diff --git a/web/service/knowledge/use-segment.ts b/web/service/knowledge/use-segment.ts
index ca1778fb9..8b3e939e7 100644
--- a/web/service/knowledge/use-segment.ts
+++ b/web/service/knowledge/use-segment.ts
@@ -154,9 +154,9 @@ export const useUpdateChildSegment = () => {
 export const useSegmentBatchImport = () => {
   return useMutation({
     mutationKey: [NAME_SPACE, 'batchImport'],
-    mutationFn: (payload: { url: string; body: FormData }) => {
+    mutationFn: (payload: { url: string; body: { upload_file_id: string } }) => {
       const { url, body } = payload
-      return post(url, { body }, { bodyStringify: false, deleteContentType: true })
+      return post(url, { body })
     },
   })
 }
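End to end, the CSV contract itself is unchanged: `pd.read_csv` consumes the first row as a header (the "Skip the first row" comment in the task), then each remaining row is read positionally, column 0 as `content` and, for `qa_model` documents, column 1 as `answer`. A sketch of a file that satisfies that contract — the column names are illustrative; only position matters:

```python
import csv

# Build a CSV the task's parsing loop accepts: one header row, then data rows
# read as row.iloc[0] -> content and row.iloc[1] -> answer (qa_model only).
with open("segments.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["content", "answer"])  # header row, consumed by pd.read_csv
    writer.writerow(["First segment text", "First answer"])
    writer.writerow(["Second segment text", "Second answer"])
```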