feat: optimize model when app create (#875)

This commit is contained in:
takatost
2023-08-16 22:29:18 +08:00
committed by GitHub
parent cc2d71c253
commit b7c29ea1b6
5 changed files with 104 additions and 23 deletions

View File

@@ -1,5 +1,6 @@
# -*- coding:utf-8 -*-
import json
import logging
from datetime import datetime
from flask_login import login_required, current_user
@@ -11,7 +12,9 @@ from controllers.console import api
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType
from events.app_event import app_was_created, app_was_deleted
from libs.helper import TimestampField
@@ -124,24 +127,34 @@ class AppListApi(Resource):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
default_model = ModelFactory.get_default_model(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_GENERATION
)
if default_model:
default_model_provider = default_model.provider_name
default_model_name = default_model.model_name
else:
raise ProviderNotInitializeError(
f"No Text Generation Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
try:
default_model = ModelFactory.get_text_generation_model(
tenant_id=current_user.current_tenant_id
)
except (ProviderTokenNotInitError, LLMBadRequestError):
default_model = None
except Exception as e:
logging.exception(e)
default_model = None
if args['model_config'] is not None:
# validate config
model_config_dict = args['model_config']
model_config_dict["model"]["provider"] = default_model_provider
model_config_dict["model"]["name"] = default_model_name
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(
current_user.current_tenant_id,
model_config_dict["model"]["provider"]
)
if not model_provider:
if not default_model:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = default_model.model_provider.provider_name
model_config_dict["model"]["name"] = default_model.name
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
@@ -169,10 +182,22 @@ class AppListApi(Resource):
app = App(**model_config_template['app'])
app_model_config = AppModelConfig(**model_config_template['model_config'])
model_dict = app_model_config.model_dict
model_dict['provider'] = default_model_provider
model_dict['name'] = default_model_name
app_model_config.model = json.dumps(model_dict)
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(
current_user.current_tenant_id,
app_model_config.model_dict["provider"]
)
if not model_provider:
if not default_model:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_dict = app_model_config.model_dict
model_dict['provider'] = default_model.model_provider.provider_name
model_dict['name'] = default_model.name
app_model_config.model = json.dumps(model_dict)
app.name = args['name']
app.mode = args['mode']