chore(api/core): apply ruff reformatting (#7624)

Author: Bowen Liang
Date: 2024-09-10 17:00:20 +08:00
Committed by: GitHub
Parent: 178730266d
Commit: 2cf1187b32

724 changed files with 21180 additions and 21123 deletions
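The change is a pure formatter pass (presumably `ruff format`, given the commit title): wrapped lines are collapsed up to the configured width, single quotes become double quotes, adjacent string literals are joined onto one line, and slice bounds that are expressions gain a space before the colon per PEP 8. None of this alters runtime behavior. A minimal sanity check of the three rewrite kinds visible in the diff below (this snippet is illustrative, not part of the commit):

# Quote normalization: '\n' and "\n" are the same one-character string.
assert '\n' == "\n"

# Implicit string concatenation: adjacent literals join identically
# whether they sit on one line or two.
a = ("Could not import transformers python package. "
     "Please install it with `pip install transformers`.")
b = "Could not import transformers python package. " "Please install it with `pip install transformers`."
assert a == b

# Slice spacing: symmetric spaces around the colon when the bound is an
# expression; the two slices are equivalent.
sep, line = "##", "## Heading"
assert line[len(sep):] == line[len(sep) :]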


@@ -22,9 +22,7 @@ logger = logging.getLogger(__name__)
 TS = TypeVar("TS", bound="TextSplitter")


-def _split_text_with_regex(
-        text: str, separator: str, keep_separator: bool
-) -> list[str]:
+def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
     # Now that we have the separator, split the text
     if separator:
         if keep_separator:
@@ -37,19 +35,19 @@ def _split_text_with_regex(
             splits = re.split(separator, text)
     else:
         splits = list(text)
-    return [s for s in splits if (s != "" and s != '\n')]
+    return [s for s in splits if (s != "" and s != "\n")]


 class TextSplitter(BaseDocumentTransformer, ABC):
     """Interface for splitting text into chunks."""

     def __init__(
-            self,
-            chunk_size: int = 4000,
-            chunk_overlap: int = 200,
-            length_function: Callable[[str], int] = len,
-            keep_separator: bool = False,
-            add_start_index: bool = False,
+        self,
+        chunk_size: int = 4000,
+        chunk_overlap: int = 200,
+        length_function: Callable[[str], int] = len,
+        keep_separator: bool = False,
+        add_start_index: bool = False,
     ) -> None:
         """Create a new TextSplitter.

@@ -62,8 +60,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         """
         if chunk_overlap > chunk_size:
             raise ValueError(
-                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
-                f"({chunk_size}), should be smaller."
+                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " f"({chunk_size}), should be smaller."
             )
         self._chunk_size = chunk_size
         self._chunk_overlap = chunk_overlap
@@ -75,9 +72,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
     def split_text(self, text: str) -> list[str]:
         """Split text into multiple components."""

-    def create_documents(
-            self, texts: list[str], metadatas: Optional[list[dict]] = None
-    ) -> list[Document]:
+    def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]:
         """Create documents from a list of texts."""
         _metadatas = metadatas or [{}] * len(texts)
         documents = []
@@ -119,14 +114,10 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         index = 0
         for d in splits:
             _len = lengths[index]
-            if (
-                    total + _len + (separator_len if len(current_doc) > 0 else 0)
-                    > self._chunk_size
-            ):
+            if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size:
                 if total > self._chunk_size:
                     logger.warning(
-                        f"Created a chunk of size {total}, "
-                        f"which is longer than the specified {self._chunk_size}"
+                        f"Created a chunk of size {total}, " f"which is longer than the specified {self._chunk_size}"
                     )
                 if len(current_doc) > 0:
                     doc = self._join_docs(current_doc, separator)
@@ -136,13 +127,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
                 # - we have a larger chunk than in the chunk overlap
                 # - or if we still have any chunks and the length is long
                 while total > self._chunk_overlap or (
-                        total + _len + (separator_len if len(current_doc) > 0 else 0)
-                        > self._chunk_size
-                        and total > 0
+                    total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0
                 ):
-                    total -= self._length_function(current_doc[0]) + (
-                        separator_len if len(current_doc) > 1 else 0
-                    )
+                    total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0)
                     current_doc = current_doc[1:]
             current_doc.append(d)
             total += _len + (separator_len if len(current_doc) > 1 else 0)
@@ -159,28 +146,25 @@ class TextSplitter(BaseDocumentTransformer, ABC):
             from transformers import PreTrainedTokenizerBase

             if not isinstance(tokenizer, PreTrainedTokenizerBase):
-                raise ValueError(
-                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"
-                )
+                raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")

             def _huggingface_tokenizer_length(text: str) -> int:
                 return len(tokenizer.encode(text))

         except ImportError:
             raise ValueError(
-                "Could not import transformers python package. "
-                "Please install it with `pip install transformers`."
+                "Could not import transformers python package. " "Please install it with `pip install transformers`."
             )
         return cls(length_function=_huggingface_tokenizer_length, **kwargs)

     @classmethod
     def from_tiktoken_encoder(
-            cls: type[TS],
-            encoding_name: str = "gpt2",
-            model_name: Optional[str] = None,
-            allowed_special: Union[Literal["all"], Set[str]] = set(),
-            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
-            **kwargs: Any,
+        cls: type[TS],
+        encoding_name: str = "gpt2",
+        model_name: Optional[str] = None,
+        allowed_special: Union[Literal["all"], Set[str]] = set(),
+        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        **kwargs: Any,
     ) -> TS:
         """Text splitter that uses tiktoken encoder to count length."""
         try:
@@ -217,15 +201,11 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         return cls(length_function=_tiktoken_encoder, **kwargs)

-    def transform_documents(
-            self, documents: Sequence[Document], **kwargs: Any
-    ) -> Sequence[Document]:
+    def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
         """Transform sequence of documents by splitting them."""
         return self.split_documents(list(documents))

-    async def atransform_documents(
-            self, documents: Sequence[Document], **kwargs: Any
-    ) -> Sequence[Document]:
+    async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
         """Asynchronously transform a sequence of documents by splitting them."""
         raise NotImplementedError
@@ -267,9 +247,7 @@ class HeaderType(TypedDict):
 class MarkdownHeaderTextSplitter:
     """Splitting markdown files based on specified headers."""

-    def __init__(
-            self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False
-    ):
+    def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False):
         """Create a new MarkdownHeaderTextSplitter.

         Args:
@@ -280,9 +258,7 @@ class MarkdownHeaderTextSplitter:
         self.return_each_line = return_each_line
         # Given the headers we want to split on,
         # (e.g., "#, ##, etc") order by length
-        self.headers_to_split_on = sorted(
-            headers_to_split_on, key=lambda split: len(split[0]), reverse=True
-        )
+        self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True)

     def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:
         """Combine lines with common metadata into chunks
@@ -292,10 +268,7 @@ class MarkdownHeaderTextSplitter:
         aggregated_chunks: list[LineType] = []

         for line in lines:
-            if (
-                    aggregated_chunks
-                    and aggregated_chunks[-1]["metadata"] == line["metadata"]
-            ):
+            if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]:
                 # If the last line in the aggregated list
                 # has the same metadata as the current line,
                 # append the current content to the last lines's content
@@ -304,10 +277,7 @@ class MarkdownHeaderTextSplitter:
                 # Otherwise, append the current line to the aggregated list
                 aggregated_chunks.append(line)

-        return [
-            Document(page_content=chunk["content"], metadata=chunk["metadata"])
-            for chunk in aggregated_chunks
-        ]
+        return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks]

     def split_text(self, text: str) -> list[Document]:
         """Split markdown file
@@ -332,10 +302,9 @@ class MarkdownHeaderTextSplitter:
             for sep, name in self.headers_to_split_on:
                 # Check if line starts with a header that we intend to split on
                 if stripped_line.startswith(sep) and (
-                        # Header with no text OR header is followed by space
-                        # Both are valid conditions that sep is being used a header
-                        len(stripped_line) == len(sep)
-                        or stripped_line[len(sep)] == " "
+                    # Header with no text OR header is followed by space
+                    # Both are valid conditions that sep is being used a header
+                    len(stripped_line) == len(sep) or stripped_line[len(sep)] == " "
                 ):
                     # Ensure we are tracking the header as metadata
                     if name is not None:
@@ -343,10 +312,7 @@ class MarkdownHeaderTextSplitter:
                         current_header_level = sep.count("#")

                        # Pop out headers of lower or same level from the stack
-                        while (
-                                header_stack
-                                and header_stack[-1]["level"] >= current_header_level
-                        ):
+                        while header_stack and header_stack[-1]["level"] >= current_header_level:
                             # We have encountered a new header
                             # at the same or higher level
                             popped_header = header_stack.pop()
@@ -359,7 +325,7 @@ class MarkdownHeaderTextSplitter:
                         header: HeaderType = {
                             "level": current_header_level,
                             "name": name,
-                            "data": stripped_line[len(sep):].strip(),
+                            "data": stripped_line[len(sep) :].strip(),
                         }
                         header_stack.append(header)
                         # Update initial_metadata with the current header
@@ -392,9 +358,7 @@ class MarkdownHeaderTextSplitter:
             current_metadata = initial_metadata.copy()

         if current_content:
-            lines_with_metadata.append(
-                {"content": "\n".join(current_content), "metadata": current_metadata}
-            )
+            lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata})

         # lines_with_metadata has each line with associated header metadata
         # aggregate these into chunks based on common metadata
@@ -402,8 +366,7 @@ class MarkdownHeaderTextSplitter:
             return self.aggregate_lines_to_chunks(lines_with_metadata)
         else:
             return [
-                Document(page_content=chunk["content"], metadata=chunk["metadata"])
-                for chunk in lines_with_metadata
+                Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata
             ]

@@ -436,12 +399,12 @@ class TokenTextSplitter(TextSplitter):
     """Splitting text to tokens using model tokenizer."""

     def __init__(
-            self,
-            encoding_name: str = "gpt2",
-            model_name: Optional[str] = None,
-            allowed_special: Union[Literal["all"], Set[str]] = set(),
-            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
-            **kwargs: Any,
+        self,
+        encoding_name: str = "gpt2",
+        model_name: Optional[str] = None,
+        allowed_special: Union[Literal["all"], Set[str]] = set(),
+        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        **kwargs: Any,
     ) -> None:
         """Create a new TextSplitter."""
         super().__init__(**kwargs)
@@ -488,10 +451,10 @@ class RecursiveCharacterTextSplitter(TextSplitter):
     """

     def __init__(
-            self,
-            separators: Optional[list[str]] = None,
-            keep_separator: bool = True,
-            **kwargs: Any,
+        self,
+        separators: Optional[list[str]] = None,
+        keep_separator: bool = True,
+        **kwargs: Any,
     ) -> None:
         """Create a new TextSplitter."""
         super().__init__(keep_separator=keep_separator, **kwargs)
@@ -508,7 +471,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 break
             if re.search(_s, text):
                 separator = _s
-                new_separators = separators[i + 1:]
+                new_separators = separators[i + 1 :]
                 break

         splits = _split_text_with_regex(text, separator, self._keep_separator)
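
For orientation, a small usage sketch of the splitter whose formatting changed above. The import path and sample text are assumptions for illustration; the parameters mirror the signatures shown in the diff (chunk_size, chunk_overlap, and keep_separator come from TextSplitter.__init__, which rejects an overlap larger than the chunk size):

# Hypothetical import path; adjust to wherever this module lives under api/core.
from core.splitter.text_splitter import RecursiveCharacterTextSplitter

text = "# Title\n\nFirst paragraph.\n\nSecond paragraph, somewhat longer."

splitter = RecursiveCharacterTextSplitter(
    chunk_size=40,        # maximum length per chunk, measured by length_function (len by default)
    chunk_overlap=10,     # must be smaller than chunk_size, enforced in __init__
    keep_separator=True,  # keep the matched separator on each split
)

for chunk in splitter.split_text(text):
    print(repr(chunk))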