feat: server multi models support (#799)

This commit is contained in:
takatost
2023-08-12 00:57:00 +08:00
committed by GitHub
parent d8b712b325
commit 5fa2161b05
213 changed files with 10556 additions and 2579 deletions

View File

@@ -255,7 +255,7 @@ class DataSourceNotionApi(Resource):
# validate args
DocumentService.estimate_args_validate(args)
indexing_runner = IndexingRunner()
response = indexing_runner.notion_indexing_estimate(args['notion_info_list'], args['process_rule'])
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
return response, 200

View File

@@ -5,10 +5,13 @@ from flask_restful import Resource, reqparse, fields, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden
import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from core.model_providers.error import LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from libs.helper import TimestampField
from extensions.ext_database import db
from models.dataset import DocumentSegment, Document
@@ -97,6 +100,15 @@ class DatasetListApi(Resource):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=current_user.current_tenant_id,
@@ -235,12 +247,26 @@ class DatasetIndexingEstimateApi(Resource):
raise NotFound("File not found.")
indexing_runner = IndexingRunner()
response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'], args['doc_form'])
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
args['process_rule'], args['doc_form'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
elif args['info_list']['data_source_type'] == 'notion_import':
indexing_runner = IndexingRunner()
response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
args['process_rule'], args['doc_form'])
try:
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
args['info_list']['notion_info_list'],
args['process_rule'], args['doc_form'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
else:
raise ValueError('Data source type not support')
return response, 200

View File

@@ -18,7 +18,9 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, Inv
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client
from libs.helper import TimestampField
from extensions.ext_database import db
@@ -280,6 +282,15 @@ class DatasetDocumentListApi(Resource):
# validate args
DocumentService.document_create_args_validate(args)
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
except ProviderTokenNotInitError as ex:
@@ -319,6 +330,15 @@ class DatasetInitApi(Resource):
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
args = parser.parse_args()
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
# validate args
DocumentService.document_create_args_validate(args)
@@ -384,7 +404,13 @@ class DocumentIndexingEstimateApi(DocumentResource):
indexing_runner = IndexingRunner()
response = indexing_runner.file_indexing_estimate([file], data_process_rule_dict)
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
data_process_rule_dict)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
return response
@@ -445,12 +471,24 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
indexing_runner = IndexingRunner()
response = indexing_runner.file_indexing_estimate(file_details, data_process_rule_dict)
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
data_process_rule_dict)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
elif dataset.data_source_type:
indexing_runner = IndexingRunner()
response = indexing_runner.notion_indexing_estimate(info_list,
data_process_rule_dict)
try:
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
info_list,
data_process_rule_dict)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
else:
raise ValueError('Data source type not support')
return response

View File

@@ -11,7 +11,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import TimestampField
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
@@ -102,6 +102,8 @@ class HitTestingApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except ValueError as e:
raise ValueError(str(e))
except Exception as e:
logging.exception("Hit testing failed.")
raise InternalServerError(str(e))