feat: optimize split rule when using a custom split segment identifier (#35)
This commit is contained in:
68
api/core/index/spiltter/fixed_text_splitter.py
Normal file
68
api/core/index/spiltter/fixed_text_splitter.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""Functionality for splitting text."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
|
|
||||||
|
|
||||||
|
class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """Text splitter that first cuts on a fixed, user-chosen separator.

    The text is split once on ``fixed_separator`` to produce candidate
    segments; any segment longer than the configured chunk size is then
    recursively re-split using the fallback ``separators`` list (tried in
    order, finest-grained last).
    """

    def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
        """Create a new TextSplitter.

        Args:
            fixed_separator: Primary separator used for the first-pass split.
                If empty/falsy, the text is split into individual characters.
            separators: Fallback separators for recursively splitting
                over-long segments; defaults to ["\\n\\n", "\\n", " ", ""].
            **kwargs: Forwarded to RecursiveCharacterTextSplitter
                (chunk_size, chunk_overlap, length_function, ...).
        """
        super().__init__(**kwargs)
        self._fixed_separator = fixed_separator
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks.

        Segments that already fit in ``chunk_size`` are returned as-is;
        over-long segments are handed to ``recursive_split_text``.
        """
        if self._fixed_separator:
            chunks = text.split(self._fixed_separator)
        else:
            chunks = list(text)

        final_chunks = []
        for chunk in chunks:
            # Bug fix: str.split yields empty strings for consecutive,
            # leading, or trailing separators — drop them so no empty
            # segments are emitted downstream.
            if not chunk:
                continue
            if self._length_function(chunk) > self._chunk_size:
                final_chunks.extend(self.recursive_split_text(chunk))
            else:
                final_chunks.append(chunk)

        return final_chunks

    def recursive_split_text(self, text: str) -> List[str]:
        """Recursively split an over-long segment and return chunks.

        Picks the first separator from ``self._separators`` that occurs in
        ``text`` (or "" as the last resort), splits on it, merges small
        pieces up to ``chunk_size`` via ``_merge_splits``, and recurses on
        pieces that are still too long.
        """
        final_chunks = []
        # Get appropriate separator to use: the first one present in the
        # text wins; "" (character-level split) always matches as fallback.
        separator = self._separators[-1]
        for _s in self._separators:
            if _s == "":
                separator = _s
                break
            if _s in text:
                separator = _s
                break

        # Now that we have the separator, split the text.
        if separator:
            splits = text.split(separator)
        else:
            splits = list(text)

        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                # Flush accumulated small pieces before recursing, so
                # ordering of the output chunks matches the input text.
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                other_info = self.recursive_split_text(s)
                final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, separator)
            final_chunks.extend(merged_text)
        return final_chunks
|
@@ -18,6 +18,7 @@ from core.docstore.dataset_docstore import DatesetDocumentStore
|
|||||||
from core.index.keyword_table_index import KeywordTableIndex
|
from core.index.keyword_table_index import KeywordTableIndex
|
||||||
from core.index.readers.html_parser import HTMLParser
|
from core.index.readers.html_parser import HTMLParser
|
||||||
from core.index.readers.pdf_parser import PDFParser
|
from core.index.readers.pdf_parser import PDFParser
|
||||||
|
from core.index.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
|
||||||
from core.index.vector_index import VectorIndex
|
from core.index.vector_index import VectorIndex
|
||||||
from core.llm.token_calculator import TokenCalculator
|
from core.llm.token_calculator import TokenCalculator
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@@ -267,16 +268,14 @@ class IndexingRunner:
|
|||||||
raise ValueError("Custom segment length should be between 50 and 1000.")
|
raise ValueError("Custom segment length should be between 50 and 1000.")
|
||||||
|
|
||||||
separator = segmentation["separator"]
|
separator = segmentation["separator"]
|
||||||
if not separator:
|
if separator:
|
||||||
separators = ["\n\n", "。", ".", " ", ""]
|
|
||||||
else:
|
|
||||||
separator = separator.replace('\\n', '\n')
|
separator = separator.replace('\\n', '\n')
|
||||||
separators = [separator, ""]
|
|
||||||
|
|
||||||
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
||||||
chunk_size=segmentation["max_tokens"],
|
chunk_size=segmentation["max_tokens"],
|
||||||
chunk_overlap=0,
|
chunk_overlap=0,
|
||||||
separators=separators
|
fixed_separator=separator,
|
||||||
|
separators=["\n\n", "。", ".", " ", ""]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Automatic segmentation
|
# Automatic segmentation
|
||||||
|
Reference in New Issue
Block a user