"""Functionality for splitting text."""
from __future__ import annotations

from typing import (
    Any,
    List,
    Optional,
)

from langchain.text_splitter import RecursiveCharacterTextSplitter


class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """Split on one fixed separator first, then recursively split oversized pieces.

    The text is first cut on ``fixed_separator``. Pieces that fit within the
    configured chunk size are kept verbatim; only pieces that are too long
    are handed to :meth:`recursive_split_text`, which works through the
    ``separators`` list.

    ``_length_function``, ``_chunk_size`` and ``_merge_splits`` are not defined
    here; they are inherited from the parent splitter.
    """

    def __init__(
        self,
        fixed_separator: str = "\n\n",
        separators: Optional[List[str]] = None,
        **kwargs: Any,
    ):
        """Create a new TextSplitter.

        Args:
            fixed_separator: Separator applied once, up front, by
                :meth:`split_text`. A falsy value (``""`` or ``None``)
                disables the up-front split.
            separators: Ordered candidates used when an oversized piece has to
                be split further. Defaults to ``["\\n\\n", "\\n", " ", ""]``
                when omitted or empty.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(**kwargs)
        self._fixed_separator = fixed_separator
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` on the fixed separator and return the chunks.

        Args:
            text: The text to split.

        Returns:
            The pieces in their original order. A piece longer than the chunk
            size is replaced by the chunks produced by
            :meth:`recursive_split_text`.
        """
        if self._fixed_separator:
            chunks = text.split(self._fixed_separator)
        else:
            # Without a fixed separator the whole text is a single piece.
            # Splitting it into characters here would make every character
            # its own chunk, because each one passes the size check below.
            chunks = [text]

        final_chunks = []
        for chunk in chunks:
            if self._length_function(chunk) > self._chunk_size:
                final_chunks.extend(self.recursive_split_text(chunk))
            else:
                final_chunks.append(chunk)

        return final_chunks

    def recursive_split_text(self, text: str) -> List[str]:
        """Split ``text`` with the separator list, recursing into long pieces.

        The first entry of ``separators`` that is ``""`` or occurs in ``text``
        is used; if none matches, the last entry is the fallback. Pieces below
        the chunk size are accumulated and merged with ``_merge_splits``;
        larger pieces are split again recursively.

        Args:
            text: The text to split.

        Returns:
            The resulting chunks in their original order.
        """
        final_chunks = []
        # Get appropriate separator to use
        separator = self._separators[-1]
        for _s in self._separators:
            if _s == "" or _s in text:
                separator = _s
                break
        # Now that we have the separator, split the text.
        # An empty separator means "split into individual characters".
        splits = text.split(separator) if separator else list(text)
        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
                continue
            # Flush the small pieces collected so far to keep output ordered.
            if _good_splits:
                final_chunks.extend(self._merge_splits(_good_splits, separator))
                _good_splits = []
            if s == text:
                # The split made no progress (fallback separator absent from
                # the text, or a single character that is still over budget).
                # Recursing on the same input would never terminate, so emit
                # the piece as-is.
                final_chunks.append(s)
            else:
                final_chunks.extend(self.recursive_split_text(s))
        if _good_splits:
            final_chunks.extend(self._merge_splits(_good_splits, separator))
        return final_chunks