chore(api/core): apply ruff reformatting (#7624)
@@ -22,9 +22,7 @@ logger = logging.getLogger(__name__)
 TS = TypeVar("TS", bound="TextSplitter")
 
 
-def _split_text_with_regex(
-    text: str, separator: str, keep_separator: bool
-) -> list[str]:
+def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
     # Now that we have the separator, split the text
     if separator:
         if keep_separator:
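
For reference, the `keep_separator` branch that this hunk cuts off re-attaches each matched separator to the split that follows it. A minimal standalone sketch (reconstructed from the upstream langchain implementation this module vendors, not part of this commit):

import re

def split_keeping_separator(text: str, separator: str) -> list[str]:
    # The capturing group makes re.split return the separators too.
    _splits = re.split(f"({separator})", text)
    # Prepend each separator to the chunk that follows it.
    splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
    if len(_splits) % 2 == 0:
        splits += _splits[-1:]
    return [s for s in ([_splits[0]] + splits) if s != ""]

print(split_keeping_separator("a. b. c", r"\. "))  # -> ['a', '. b', '. c']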
@@ -37,19 +35,19 @@ def _split_text_with_regex(
             splits = re.split(separator, text)
     else:
         splits = list(text)
-    return [s for s in splits if (s != "" and s != '\n')]
+    return [s for s in splits if (s != "" and s != "\n")]
 
 
 class TextSplitter(BaseDocumentTransformer, ABC):
     """Interface for splitting text into chunks."""
 
     def __init__(
         self,
         chunk_size: int = 4000,
         chunk_overlap: int = 200,
         length_function: Callable[[str], int] = len,
         keep_separator: bool = False,
         add_start_index: bool = False,
     ) -> None:
         """Create a new TextSplitter.
 
@@ -62,8 +60,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         """
         if chunk_overlap > chunk_size:
             raise ValueError(
-                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
-                f"({chunk_size}), should be smaller."
+                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " f"({chunk_size}), should be smaller."
             )
         self._chunk_size = chunk_size
         self._chunk_overlap = chunk_overlap
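
The merged f-string above keeps the original message intact; the guard itself is unchanged. A quick illustration of what it rejects (the values are hypothetical):

# chunk_overlap may not exceed chunk_size:
splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=200)
# ValueError: Got a larger chunk overlap (200) than chunk size (100), should be smaller.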
@@ -75,9 +72,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
     def split_text(self, text: str) -> list[str]:
         """Split text into multiple components."""
 
-    def create_documents(
-        self, texts: list[str], metadatas: Optional[list[dict]] = None
-    ) -> list[Document]:
+    def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]:
         """Create documents from a list of texts."""
         _metadatas = metadatas or [{}] * len(texts)
         documents = []
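
A hypothetical call to the reformatted `create_documents` (the texts and metadata are made-up examples):

splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=20)
docs = splitter.create_documents(
    ["first source text", "second source text"],
    metadatas=[{"source": "a.md"}, {"source": "b.md"}],
)
# Each returned Document pairs a chunk with the metadata of the text it came from.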
@@ -119,14 +114,10 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         index = 0
         for d in splits:
             _len = lengths[index]
-            if (
-                total + _len + (separator_len if len(current_doc) > 0 else 0)
-                > self._chunk_size
-            ):
+            if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size:
                 if total > self._chunk_size:
                     logger.warning(
-                        f"Created a chunk of size {total}, "
-                        f"which is longer than the specified {self._chunk_size}"
+                        f"Created a chunk of size {total}, " f"which is longer than the specified {self._chunk_size}"
                     )
                 if len(current_doc) > 0:
                     doc = self._join_docs(current_doc, separator)
@@ -136,13 +127,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
                     # - we have a larger chunk than in the chunk overlap
                     # - or if we still have any chunks and the length is long
                     while total > self._chunk_overlap or (
-                        total + _len + (separator_len if len(current_doc) > 0 else 0)
-                        > self._chunk_size
-                        and total > 0
+                        total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0
                     ):
-                        total -= self._length_function(current_doc[0]) + (
-                            separator_len if len(current_doc) > 1 else 0
-                        )
+                        total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0)
                         current_doc = current_doc[1:]
             current_doc.append(d)
             total += _len + (separator_len if len(current_doc) > 1 else 0)
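
The merged `while` condition drives a sliding window: once a chunk is emitted, splits are dropped from the left until the running total fits the overlap budget. A standalone re-implementation of just the popping arithmetic, for illustration (names and values are assumptions, not part of the commit):

def pop_for_overlap(lengths: list[int], total: int, chunk_overlap: int, separator_len: int = 1) -> tuple[list[int], int]:
    # Drop split lengths from the left until the running total is within the overlap budget.
    while lengths and total > chunk_overlap:
        total -= lengths[0] + (separator_len if len(lengths) > 1 else 0)
        lengths = lengths[1:]
    return lengths, total

print(pop_for_overlap([4, 4], 9, 3))  # -> ([], 0): both splits get popped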
@@ -159,28 +146,25 @@ class TextSplitter(BaseDocumentTransformer, ABC):
             from transformers import PreTrainedTokenizerBase
 
             if not isinstance(tokenizer, PreTrainedTokenizerBase):
-                raise ValueError(
-                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"
-                )
+                raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")
 
             def _huggingface_tokenizer_length(text: str) -> int:
                 return len(tokenizer.encode(text))
 
         except ImportError:
             raise ValueError(
-                "Could not import transformers python package. "
-                "Please install it with `pip install transformers`."
+                "Could not import transformers python package. " "Please install it with `pip install transformers`."
             )
         return cls(length_function=_huggingface_tokenizer_length, **kwargs)
 
     @classmethod
     def from_tiktoken_encoder(
         cls: type[TS],
         encoding_name: str = "gpt2",
         model_name: Optional[str] = None,
         allowed_special: Union[Literal["all"], Set[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
     ) -> TS:
         """Text splitter that uses tiktoken encoder to count length."""
         try:
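
Hypothetical usage of the two factory constructors touched above (the model and encoding names are placeholders):

from transformers import AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(hf_tokenizer, chunk_size=512, chunk_overlap=64)

tik_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(encoding_name="gpt2", chunk_size=512, chunk_overlap=64)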
@@ -217,15 +201,11 @@ class TextSplitter(BaseDocumentTransformer, ABC):
 
         return cls(length_function=_tiktoken_encoder, **kwargs)
 
-    def transform_documents(
-        self, documents: Sequence[Document], **kwargs: Any
-    ) -> Sequence[Document]:
+    def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
         """Transform sequence of documents by splitting them."""
         return self.split_documents(list(documents))
 
-    async def atransform_documents(
-        self, documents: Sequence[Document], **kwargs: Any
-    ) -> Sequence[Document]:
+    async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
         """Asynchronously transform a sequence of documents by splitting them."""
         raise NotImplementedError
 
@@ -267,9 +247,7 @@ class HeaderType(TypedDict):
 class MarkdownHeaderTextSplitter:
     """Splitting markdown files based on specified headers."""
 
-    def __init__(
-        self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False
-    ):
+    def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False):
         """Create a new MarkdownHeaderTextSplitter.
 
         Args:
@@ -280,9 +258,7 @@ class MarkdownHeaderTextSplitter:
         self.return_each_line = return_each_line
         # Given the headers we want to split on,
         # (e.g., "#, ##, etc") order by length
-        self.headers_to_split_on = sorted(
-            headers_to_split_on, key=lambda split: len(split[0]), reverse=True
-        )
+        self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True)
 
     def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:
         """Combine lines with common metadata into chunks
@@ -292,10 +268,7 @@ class MarkdownHeaderTextSplitter:
         aggregated_chunks: list[LineType] = []
 
         for line in lines:
-            if (
-                aggregated_chunks
-                and aggregated_chunks[-1]["metadata"] == line["metadata"]
-            ):
+            if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]:
                 # If the last line in the aggregated list
                 # has the same metadata as the current line,
                 # append the current content to the last lines's content
@@ -304,10 +277,7 @@ class MarkdownHeaderTextSplitter:
                 # Otherwise, append the current line to the aggregated list
                 aggregated_chunks.append(line)
 
-        return [
-            Document(page_content=chunk["content"], metadata=chunk["metadata"])
-            for chunk in aggregated_chunks
-        ]
+        return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks]
 
     def split_text(self, text: str) -> list[Document]:
         """Split markdown file
@@ -332,10 +302,9 @@ class MarkdownHeaderTextSplitter:
             for sep, name in self.headers_to_split_on:
                 # Check if line starts with a header that we intend to split on
                 if stripped_line.startswith(sep) and (
-                    # Header with no text OR header is followed by space
-                    # Both are valid conditions that sep is being used a header
-                    len(stripped_line) == len(sep)
-                    or stripped_line[len(sep)] == " "
+                    # Header with no text OR header is followed by space
+                    # Both are valid conditions that sep is being used a header
+                    len(stripped_line) == len(sep) or stripped_line[len(sep)] == " "
                 ):
                     # Ensure we are tracking the header as metadata
                     if name is not None:
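
The merged condition accepts "#" on its own line and "# Title", but rejects "#Title". A quick standalone check (variable names mirror the diff):

stripped_line = "# Title"
sep = "#"
is_header = stripped_line.startswith(sep) and (
    len(stripped_line) == len(sep) or stripped_line[len(sep)] == " "
)
print(is_header)  # -> True; "#Title" would print False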
@@ -343,10 +312,7 @@ class MarkdownHeaderTextSplitter:
                         current_header_level = sep.count("#")
 
                         # Pop out headers of lower or same level from the stack
-                        while (
-                            header_stack
-                            and header_stack[-1]["level"] >= current_header_level
-                        ):
+                        while header_stack and header_stack[-1]["level"] >= current_header_level:
                             # We have encountered a new header
                             # at the same or higher level
                             popped_header = header_stack.pop()
@@ -359,7 +325,7 @@ class MarkdownHeaderTextSplitter:
                         header: HeaderType = {
                             "level": current_header_level,
                             "name": name,
-                            "data": stripped_line[len(sep):].strip(),
+                            "data": stripped_line[len(sep) :].strip(),
                         }
                         header_stack.append(header)
                         # Update initial_metadata with the current header
@@ -392,9 +358,7 @@ class MarkdownHeaderTextSplitter:
             current_metadata = initial_metadata.copy()
 
         if current_content:
-            lines_with_metadata.append(
-                {"content": "\n".join(current_content), "metadata": current_metadata}
-            )
+            lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata})
 
         # lines_with_metadata has each line with associated header metadata
         # aggregate these into chunks based on common metadata
@@ -402,8 +366,7 @@ class MarkdownHeaderTextSplitter:
             return self.aggregate_lines_to_chunks(lines_with_metadata)
         else:
             return [
-                Document(page_content=chunk["content"], metadata=chunk["metadata"])
-                for chunk in lines_with_metadata
+                Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata
             ]
 
 
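
A hypothetical end-to-end call for the class reformatted above (the header tuples and input are example values):

headers = [("#", "Header 1"), ("##", "Header 2")]
md_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers)
docs = md_splitter.split_text("# Intro\nSome text\n## Details\nMore text")
# The second chunk's metadata would be {"Header 1": "Intro", "Header 2": "Details"}.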
@@ -436,12 +399,12 @@ class TokenTextSplitter(TextSplitter):
     """Splitting text to tokens using model tokenizer."""
 
     def __init__(
         self,
         encoding_name: str = "gpt2",
         model_name: Optional[str] = None,
         allowed_special: Union[Literal["all"], Set[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
     ) -> None:
         """Create a new TextSplitter."""
         super().__init__(**kwargs)
@@ -488,10 +451,10 @@ class RecursiveCharacterTextSplitter(TextSplitter):
     """
 
     def __init__(
         self,
         separators: Optional[list[str]] = None,
         keep_separator: bool = True,
         **kwargs: Any,
     ) -> None:
         """Create a new TextSplitter."""
         super().__init__(keep_separator=keep_separator, **kwargs)
@@ -508,7 +471,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 break
             if re.search(_s, text):
                 separator = _s
-                new_separators = separators[i + 1:]
+                new_separators = separators[i + 1 :]
                 break
 
         splits = _split_text_with_regex(text, separator, self._keep_separator)
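
An illustrative call into the reformatted recursive splitter (the separator list below is the usual default, assumed here):

splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n", " ", ""], chunk_size=100, chunk_overlap=20)
chunks = splitter.split_text("first paragraph\n\nsecond paragraph " * 10)
# Falls through the separator list until each chunk fits within chunk_size.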