feat: server multi models support (#799)

Author: takatost
Date: 2023-08-12 00:57:00 +08:00
Committed by: GitHub
Parent: d8b712b325
Commit: 5fa2161b05

213 changed files with 10556 additions and 2579 deletions
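The hunks below are from the IndexingRunner class. The central pattern of the migration: hard-coded OpenAI choices (text-embedding-ada-002 for embeddings, gpt-3.5-turbo for QA generation) are replaced by per-tenant model lookups through ModelFactory. A minimal sketch of the new call pattern, with method signatures inferred from the hunks below rather than from the library source:

    from core.model_providers.model_factory import ModelFactory
    from core.model_providers.models.entity.message import MessageType

    tenant_id = "tenant-uuid"  # illustrative value

    # Resolve the tenant's configured models instead of assuming OpenAI.
    embedding_model = ModelFactory.get_embedding_model(tenant_id=tenant_id)
    text_generation_model = ModelFactory.get_text_generation_model(tenant_id=tenant_id)

    # Token counting and pricing go through the resolved model objects.
    tokens = embedding_model.get_num_tokens("some chunk of text")
    embedding_price = embedding_model.get_token_price(tokens)
    completion_price = text_generation_model.get_token_price(2000, MessageType.HUMAN)
    currency = embedding_model.get_currency()

Resolving a model can fail when the tenant has no provider credentials configured, which is presumably why the ProviderTokenNotInitError import moves from core.llm.error to core.model_providers.error below.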


@@ -1,4 +1,3 @@
-import concurrent
 import datetime
 import json
 import logging
@@ -6,7 +5,6 @@ import re
 import threading
 import time
 import uuid
-from concurrent.futures import ThreadPoolExecutor
 from typing import Optional, List, cast

 from flask_login import current_user
@@ -18,11 +16,10 @@ from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatesetDocumentStore
 from core.generator.llm_generator import LLMGenerator
 from core.index.index import IndexBuilder
-from core.llm.error import ProviderTokenNotInitError
-from core.llm.llm_builder import LLMBuilder
-from core.llm.streamable_open_ai import StreamableOpenAI
+from core.model_providers.error import ProviderTokenNotInitError
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.entity.message import MessageType
 from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
-from core.llm.token_calculator import TokenCalculator
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
@@ -35,9 +32,8 @@ from models.source import DataSourceBinding

 class IndexingRunner:

-    def __init__(self, embedding_model_name: str = "text-embedding-ada-002"):
+    def __init__(self):
         self.storage = storage
-        self.embedding_model_name = embedding_model_name

     def run(self, dataset_documents: List[DatasetDocument]):
         """Run the indexing process."""
@@ -227,11 +223,15 @@ class IndexingRunner:
             dataset_document.stopped_at = datetime.datetime.utcnow()
             db.session.commit()

-    def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict,
+    def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
                                doc_form: str = None) -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = ModelFactory.get_embedding_model(
+            tenant_id=tenant_id
+        )
+
         tokens = 0
         preview_texts = []
         total_segments = 0
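Both estimate entry points now take an explicit tenant_id. A sketch of a hypothetical call site, reusing the current_user pattern this file already imports (argument values are illustrative):

    estimate = runner.file_indexing_estimate(
        tenant_id=current_user.current_tenant_id,
        file_details=file_details,
        tmp_processing_rule=processing_rule,
        doc_form='qa_model',
    )

Fetching embedding_model at the top of the method also means a missing provider configuration fails fast, before any files are parsed or split.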
@@ -253,44 +253,49 @@ class IndexingRunner:
                 splitter=splitter,
                 processing_rule=processing_rule
             )
             total_segments += len(documents)
             for document in documents:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)

-                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name,
-                                                         self.filter_string(document.page_content))
+                tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
+
+        text_generation_model = ModelFactory.get_text_generation_model(
+            tenant_id=tenant_id
+        )

         if doc_form and doc_form == 'qa_model':
             if len(preview_texts) > 0:
                 # qa model document
-                llm: StreamableOpenAI = LLMBuilder.to_llm(
-                    tenant_id=current_user.current_tenant_id,
-                    model_name='gpt-3.5-turbo',
-                    max_tokens=2000
-                )
-
-                response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
+                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
                 document_qa_list = self.format_split_text(response)

             return {
                 "total_segments": total_segments * 20,
                 "tokens": total_segments * 2000,
                 "total_price": '{:f}'.format(
-                    TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
-                "currency": TokenCalculator.get_currency(self.embedding_model_name),
+                    text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                "currency": embedding_model.get_currency(),
                 "qa_preview": document_qa_list,
                 "preview": preview_texts
             }

         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
-            "currency": TokenCalculator.get_currency(self.embedding_model_name),
+            "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
+            "currency": embedding_model.get_currency(),
             "preview": preview_texts
         }

-    def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
+    def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = ModelFactory.get_embedding_model(
+            tenant_id=tenant_id
+        )
+
         # load data from notion
         tokens = 0
         preview_texts = []
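The QA-mode estimate keeps its original heuristics: each source segment is assumed to expand into roughly 20 QA segments and to cost about 2000 completion tokens, and only the first preview text is actually sent through the model. What changes is who prices those tokens. A worked sketch with illustrative numbers (the per-token rate is made up; real rates come from the tenant's provider configuration):

    total_segments = 10
    estimated_tokens = total_segments * 2000   # 20,000 completion tokens
    # text_generation_model.get_token_price(estimated_tokens, MessageType.HUMAN)
    # would multiply by the provider's configured rate, for example:
    price = estimated_tokens * 0.000002        # 0.04 at a hypothetical $0.002/1K rate

One visible seam in both estimate methods: the QA price comes from the text-generation model while "currency" still comes from the embedding model, so a tenant mixing providers with different billing currencies could see a mismatched label.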
@@ -336,31 +341,31 @@ class IndexingRunner:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)

-                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
+                tokens += embedding_model.get_num_tokens(document.page_content)
+
+        text_generation_model = ModelFactory.get_text_generation_model(
+            tenant_id=tenant_id
+        )

         if doc_form and doc_form == 'qa_model':
             if len(preview_texts) > 0:
                 # qa model document
-                llm: StreamableOpenAI = LLMBuilder.to_llm(
-                    tenant_id=current_user.current_tenant_id,
-                    model_name='gpt-3.5-turbo',
-                    max_tokens=2000
-                )
-
-                response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
+                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
                 document_qa_list = self.format_split_text(response)

             return {
                 "total_segments": total_segments * 20,
                 "tokens": total_segments * 2000,
                 "total_price": '{:f}'.format(
-                    TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
-                "currency": TokenCalculator.get_currency(self.embedding_model_name),
+                    text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                "currency": embedding_model.get_currency(),
                 "qa_preview": document_qa_list,
                 "preview": preview_texts
             }

         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
-            "currency": TokenCalculator.get_currency(self.embedding_model_name),
+            "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
+            "currency": embedding_model.get_currency(),
             "preview": preview_texts
         }
@@ -459,7 +464,6 @@ class IndexingRunner:
         doc_store = DatesetDocumentStore(
             dataset=dataset,
             user_id=dataset_document.created_by,
-            embedding_model_name=self.embedding_model_name,
             document_id=dataset_document.id
         )
@@ -513,17 +517,12 @@ class IndexingRunner:
             all_documents.extend(split_documents)

         # processing qa document
         if document_form == 'qa_model':
-            llm: StreamableOpenAI = LLMBuilder.to_llm(
-                tenant_id=tenant_id,
-                model_name='gpt-3.5-turbo',
-                max_tokens=2000
-            )
             for i in range(0, len(all_documents), 10):
                 threads = []
                 sub_documents = all_documents[i:i + 10]
                 for doc in sub_documents:
                     document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
-                        'llm': llm, 'document_node': doc, 'all_qa_documents': all_qa_documents})
+                        'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents})
                     threads.append(document_format_thread)
                     document_format_thread.start()
                 for thread in threads:
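The fan-out pattern is unchanged: QA generation runs in batches of 10, one thread per document, joining each batch before the next begins. What changes is that workers receive a tenant_id instead of a shared llm handle, so each thread resolves its own model client. A standalone sketch of the same batching pattern:

    import threading

    def run_in_batches(items, worker, batch_size=10):
        # One thread per item within a batch; join the whole batch
        # before starting the next, bounding concurrency at batch_size.
        for i in range(0, len(items), batch_size):
            threads = []
            for item in items[i:i + batch_size]:
                t = threading.Thread(target=worker, args=(item,))
                threads.append(t)
                t.start()
            for t in threads:
                t.join()

Passing tenant_id rather than a live client also removes any question of whether the old StreamableOpenAI instance was safe to share across ten concurrent threads.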
@@ -531,13 +530,13 @@ class IndexingRunner:
             return all_qa_documents

         return all_documents

-    def format_qa_document(self, llm: StreamableOpenAI, document_node, all_qa_documents):
+    def format_qa_document(self, tenant_id: str, document_node, all_qa_documents):
         format_documents = []
         if document_node.page_content is None or not document_node.page_content.strip():
             return

         try:
             # qa model document
-            response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
+            response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content)
             document_qa_list = self.format_split_text(response)
             qa_documents = []
             for result in document_qa_list:
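Each worker appends its formatted documents to the shared all_qa_documents list. A sketch of that contract, assuming CPython, where single list mutations like append and extend are effectively atomic under the GIL and so need no explicit lock (an assumption about why the code omits one):

    results = []  # stands in for all_qa_documents

    def build_qa_docs(tenant_id, document_node):
        # Hypothetical helper standing in for generate_qa_document + format_split_text.
        return [f"QA pairs for tenant {tenant_id}: {document_node}"]

    def worker(tenant_id, document_node):
        results.extend(build_qa_docs(tenant_id, document_node))  # GIL-atomic in CPython

The try: block around generation also matters inside a thread: an uncaught exception would end the worker with only a printed traceback, and that document's QA pairs would silently be missing from the results.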
@@ -638,6 +637,10 @@ class IndexingRunner:
         vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         keyword_table_index = IndexBuilder.get_index(dataset, 'economy')

+        embedding_model = ModelFactory.get_embedding_model(
+            tenant_id=dataset.tenant_id
+        )
+
         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
         tokens = 0
@@ -648,7 +651,7 @@ class IndexingRunner:
             chunk_documents = documents[i:i + chunk_size]

             tokens += sum(
-                TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
+                embedding_model.get_num_tokens(document.page_content)
                 for document in chunk_documents
             )