Feat:remove estimation of embedding cost (#7950)

Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
KVOJJJin
2024-09-04 14:41:47 +08:00
committed by GitHub
parent 83e84865be
commit 14af87527f
14 changed files with 122 additions and 162 deletions

View File

@@ -16,9 +16,7 @@ from configs import dify_config
from core.errors.error import ProviderTokenNotInitError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
@@ -255,11 +253,8 @@ class IndexingRunner:
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
tokens = 0
preview_texts = []
total_segments = 0
total_price = 0
currency = 'USD'
index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
all_text_docs = []
@@ -286,54 +281,22 @@ class IndexingRunner:
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
if indexing_technique == 'high_quality' or embedding_model_instance:
tokens += embedding_model_instance.get_text_embedding_num_tokens(
texts=[self.filter_string(document.page_content)]
)
if doc_form and doc_form == 'qa_model':
model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM
)
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language)
document_qa_list = self.format_split_text(response)
price_info = model_type_instance.get_price(
model=model_instance.model,
credentials=model_instance.credentials,
price_type=PriceType.INPUT,
tokens=total_segments * 2000,
)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(price_info.total_amount),
"currency": price_info.currency,
"qa_preview": document_qa_list,
"preview": preview_texts
}
if embedding_model_instance:
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
embedding_price_info = embedding_model_type_instance.get_price(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
total_price = '{:f}'.format(embedding_price_info.total_amount)
currency = embedding_price_info.currency
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": total_price,
"currency": currency,
"preview": preview_texts
}

View File

@@ -108,7 +108,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
else:
return text
def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
# We now want to combine these smaller pieces into medium size
# chunks to send to the LLM.
separator_len = self._length_function(separator)
@@ -116,8 +116,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
docs = []
current_doc: list[str] = []
total = 0
index = 0
for d in splits:
_len = self._length_function(d)
_len = lengths[index]
if (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> self._chunk_size
@@ -145,6 +146,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
current_doc = current_doc[1:]
current_doc.append(d)
total += _len + (separator_len if len(current_doc) > 1 else 0)
index += 1
doc = self._join_docs(current_doc, separator)
if doc is not None:
docs.append(doc)
@@ -493,11 +495,10 @@ class RecursiveCharacterTextSplitter(TextSplitter):
self._separators = separators or ["\n\n", "\n", " ", ""]
def _split_text(self, text: str, separators: list[str]) -> list[str]:
"""Split incoming text and return chunks."""
final_chunks = []
# Get appropriate separator to use
separator = separators[-1]
new_separators = []
for i, _s in enumerate(separators):
if _s == "":
separator = _s
@@ -508,25 +509,31 @@ class RecursiveCharacterTextSplitter(TextSplitter):
break
splits = _split_text_with_regex(text, separator, self._keep_separator)
# Now go merging things, recursively splitting longer texts.
_good_splits = []
_good_splits_lengths = [] # cache the lengths of the splits
_separator = "" if self._keep_separator else separator
for s in splits:
if self._length_function(s) < self._chunk_size:
s_len = self._length_function(s)
if s_len < self._chunk_size:
_good_splits.append(s)
_good_splits_lengths.append(s_len)
else:
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
final_chunks.extend(merged_text)
_good_splits = []
_good_splits_lengths = []
if not new_separators:
final_chunks.append(s)
else:
other_info = self._split_text(s, new_separators)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
final_chunks.extend(merged_text)
return final_chunks
def split_text(self, text: str) -> list[str]: