Feature/mutil embedding model (#908)

Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
Jyong
2023-08-18 17:37:31 +08:00
committed by GitHub
parent 4420281d96
commit db7156dafd
54 changed files with 1704 additions and 278 deletions

View File

@@ -10,13 +10,15 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from core.model_providers.error import LLMBadRequestError
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelType
from libs.helper import TimestampField
from extensions.ext_database import db
from models.dataset import DocumentSegment, Document
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService
from services.provider_service import ProviderService
dataset_detail_fields = {
'id': fields.String,
@@ -33,6 +35,9 @@ dataset_detail_fields = {
'created_at': TimestampField,
'updated_by': fields.String,
'updated_at': TimestampField,
'embedding_model': fields.String,
'embedding_model_provider': fields.String,
'embedding_available': fields.Boolean
}
dataset_query_detail_fields = {
@@ -74,8 +79,22 @@ class DatasetListApi(Resource):
datasets, total = DatasetService.get_datasets(page, limit, provider,
current_user.current_tenant_id, current_user)
# check embedding setting
provider_service = ProviderService()
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
# if len(valid_model_list) == 0:
# raise ProviderNotInitializeError(
# f"No Embedding Model available. Please configure a valid provider "
# f"in the Settings -> Model Provider.")
model_names = [item['model_name'] for item in valid_model_list]
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item['embedding_model'] in model_names:
item['embedding_available'] = True
else:
item['embedding_available'] = False
response = {
'data': marshal(datasets, dataset_detail_fields),
'data': data,
'has_more': len(datasets) == limit,
'limit': limit,
'total': total,
@@ -99,7 +118,6 @@ class DatasetListApi(Resource):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
@@ -233,6 +251,8 @@ class DatasetIndexingEstimateApi(Resource):
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
args = parser.parse_args()
# validate args
DocumentService.estimate_args_validate(args)
@@ -250,11 +270,14 @@ class DatasetIndexingEstimateApi(Resource):
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
args['process_rule'], args['doc_form'])
args['process_rule'], args['doc_form'],
args['doc_language'], args['dataset_id'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
elif args['info_list']['data_source_type'] == 'notion_import':
indexing_runner = IndexingRunner()
@@ -262,11 +285,14 @@ class DatasetIndexingEstimateApi(Resource):
try:
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
args['info_list']['notion_info_list'],
args['process_rule'], args['doc_form'])
args['process_rule'], args['doc_form'],
args['doc_language'], args['dataset_id'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
else:
raise ValueError('Data source type not support')
return response, 200