Model Runtime (#1858)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
takatost
2024-01-02 23:42:00 +08:00
committed by GitHub
parent e91dd28a76
commit d069c668f8
807 changed files with 171310 additions and 23806 deletions

View File

@@ -1,16 +1,19 @@
import io
from flask import send_file
from flask_login import current_user
from libs.login import login_required
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import CredentialsValidateFailedError
from services.provider_service import ProviderService
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
class ModelProviderListApi(Resource):
@@ -22,13 +25,36 @@ class ModelProviderListApi(Resource):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model_type', type=str, required=False, nullable=True, location='args')
parser.add_argument('model_type', type=str, required=False, nullable=True,
choices=[mt.value for mt in ModelType], location='args')
args = parser.parse_args()
provider_service = ProviderService()
provider_list = provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get('model_type'))
model_provider_service = ModelProviderService()
provider_list = model_provider_service.get_provider_list(
tenant_id=tenant_id,
model_type=args.get('model_type')
)
return provider_list
return jsonable_encoder({"data": provider_list})
class ModelProviderCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
credentials = model_provider_service.get_provider_credentials(
tenant_id=tenant_id,
provider=provider
)
return {
"credentials": credentials
}
class ModelProviderValidateApi(Resource):
@@ -36,21 +62,24 @@ class ModelProviderValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
def post(self, provider: str):
parser = reqparse.RequestParser()
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
result = True
error = None
try:
provider_service.custom_provider_config_validate(
provider_name=provider_name,
config=args['config']
model_provider_service.provider_credentials_validate(
tenant_id=tenant_id,
provider=provider,
credentials=args['credentials']
)
except CredentialsValidateFailedError as ex:
result = False
@@ -64,26 +93,26 @@ class ModelProviderValidateApi(Resource):
return response
class ModelProviderUpdateApi(Resource):
class ModelProviderApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
def post(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
model_provider_service = ModelProviderService()
try:
provider_service.save_custom_provider_config(
model_provider_service.save_provider_credentials(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
config=args['config']
provider=provider,
credentials=args['credentials']
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
@@ -93,109 +122,36 @@ class ModelProviderUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, provider_name: str):
def delete(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
provider_service = ProviderService()
provider_service.delete_custom_provider(
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credentials(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name
provider=provider
)
return {'result': 'success'}, 204
class ModelProviderModelValidateApi(Resource):
class ModelProviderIconApi(Resource):
"""
Get model provider icon
"""
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
result = True
error = None
try:
provider_service.custom_provider_model_config_validate(
provider_name=provider_name,
model_name=args['model_name'],
model_type=args['model_type'],
config=args['config']
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
class ModelProviderModelUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
try:
provider_service.add_or_save_custom_provider_model_config(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
model_name=args['model_name'],
model_type=args['model_type'],
config=args['config']
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {'result': 'success'}, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, provider_name: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
args = parser.parse_args()
provider_service = ProviderService()
provider_service.delete_custom_provider_model(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
model_name=args['model_name'],
model_type=args['model_type']
def get(self, provider: str, icon_type: str, lang: str):
model_provider_service = ModelProviderService()
icon, mimetype = model_provider_service.get_model_provider_icon(
provider=provider,
icon_type=icon_type,
lang=lang
)
return {'result': 'success'}, 204
return send_file(io.BytesIO(icon), mimetype=mimetype)
class PreferredProviderTypeUpdateApi(Resource):
@@ -203,71 +159,36 @@ class PreferredProviderTypeUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
def post(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
choices=['system', 'custom'], location='json')
args = parser.parse_args()
provider_service = ProviderService()
provider_service.switch_preferred_provider(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
model_provider_service = ModelProviderService()
model_provider_service.switch_preferred_provider(
tenant_id=tenant_id,
provider=provider,
preferred_provider_type=args['preferred_provider_type']
)
return {'result': 'success'}
class ModelProviderModelParameterRuleApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
args = parser.parse_args()
provider_service = ProviderService()
try:
parameter_rules = provider_service.get_model_parameter_rules(
tenant_id=current_user.current_tenant_id,
model_provider_name=provider_name,
model_name=args['model_name'],
model_type='text-generation'
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"Current Text Generation Model is invalid. Please switch to the available model.")
rules = {
k: {
'enabled': v.enabled,
'min': v.min,
'max': v.max,
'default': v.default,
'precision': v.precision
}
for k, v in vars(parameter_rules).items()
}
return rules
class ModelProviderPaymentCheckoutUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
if provider_name != 'anthropic':
raise ValueError(f'provider name {provider_name} is invalid')
def get(self, provider: str):
if provider != 'anthropic':
raise ValueError(f'provider name {provider} is invalid')
data = BillingService.get_model_provider_payment_link(provider_name=provider_name,
data = BillingService.get_model_provider_payment_link(provider_name=provider,
tenant_id=current_user.current_tenant_id,
account_id=current_user.id)
return data
@@ -277,11 +198,11 @@ class ModelProviderFreeQuotaSubmitApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
provider_service = ProviderService()
result = provider_service.free_quota_submit(
def post(self, provider: str):
model_provider_service = ModelProviderService()
result = model_provider_service.free_quota_submit(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name
provider=provider
)
return result
@@ -291,15 +212,15 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
def get(self, provider: str):
parser = reqparse.RequestParser()
parser.add_argument('token', type=str, required=False, nullable=True, location='args')
args = parser.parse_args()
provider_service = ProviderService()
result = provider_service.free_quota_qualification_verify(
model_provider_service = ModelProviderService()
result = model_provider_service.free_quota_qualification_verify(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
provider=provider,
token=args['token']
)
@@ -307,19 +228,18 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
api.add_resource(ModelProviderModelValidateApi,
'/workspaces/current/model-providers/<string:provider_name>/models/validate')
api.add_resource(ModelProviderModelUpdateApi,
'/workspaces/current/model-providers/<string:provider_name>/models')
api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers/<string:provider>/credentials')
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider>/credentials/validate')
api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/<string:provider>')
api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers/<string:provider>/'
'<string:icon_type>/<string:lang>')
api.add_resource(PreferredProviderTypeUpdateApi,
'/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type')
api.add_resource(ModelProviderModelParameterRuleApi,
'/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
'/workspaces/current/model-providers/<string:provider>/preferred-provider-type')
api.add_resource(ModelProviderPaymentCheckoutUrlApi,
'/workspaces/current/model-providers/<string:provider_name>/checkout-url')
'/workspaces/current/model-providers/<string:provider>/checkout-url')
api.add_resource(ModelProviderFreeQuotaSubmitApi,
'/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')
'/workspaces/current/model-providers/<string:provider>/free-quota-submit')
api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
'/workspaces/current/model-providers/<string:provider_name>/free-quota-qualification-verify')
'/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify')