immediately return initialed tiktokenizer instance and remove dead code in usage of tiktokenizer (#17957)
This commit is contained in:
@@ -159,50 +159,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
)
|
||||
return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_tiktoken_encoder(
|
||||
cls: type[TS],
|
||||
encoding_name: str = "gpt2",
|
||||
model_name: Optional[str] = None,
|
||||
allowed_special: Union[Literal["all"], Set[str]] = set(),
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
**kwargs: Any,
|
||||
) -> TS:
|
||||
"""Text splitter that uses tiktoken encoder to count length."""
|
||||
try:
|
||||
import tiktoken
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import tiktoken python package. "
|
||||
"This is needed in order to calculate max_tokens_for_prompt. "
|
||||
"Please install it with `pip install tiktoken`."
|
||||
)
|
||||
|
||||
if model_name is not None:
|
||||
enc = tiktoken.encoding_for_model(model_name)
|
||||
else:
|
||||
enc = tiktoken.get_encoding(encoding_name)
|
||||
|
||||
def _tiktoken_encoder(text: str) -> int:
|
||||
return len(
|
||||
enc.encode(
|
||||
text,
|
||||
allowed_special=allowed_special,
|
||||
disallowed_special=disallowed_special,
|
||||
)
|
||||
)
|
||||
|
||||
if issubclass(cls, TokenTextSplitter):
|
||||
extra_kwargs = {
|
||||
"encoding_name": encoding_name,
|
||||
"model_name": model_name,
|
||||
"allowed_special": allowed_special,
|
||||
"disallowed_special": disallowed_special,
|
||||
}
|
||||
kwargs = {**kwargs, **extra_kwargs}
|
||||
|
||||
return cls(length_function=lambda x: [_tiktoken_encoder(text) for text in x], **kwargs)
|
||||
|
||||
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
|
||||
"""Transform sequence of documents by splitting them."""
|
||||
return self.split_documents(list(documents))
|
||||
|
Reference in New Issue
Block a user