Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: Garfield Dai <dai.hai@foxmail.com> Co-authored-by: chenhe <guchenhe@gmail.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
@@ -1,16 +1,19 @@
|
||||
import io
|
||||
|
||||
from flask import send_file
|
||||
from flask_login import current_user
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from services.provider_service import ProviderService
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from services.billing_service import BillingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
class ModelProviderListApi(Resource):
|
||||
@@ -22,13 +25,36 @@ class ModelProviderListApi(Resource):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_type', type=str, required=False, nullable=True, location='args')
|
||||
parser.add_argument('model_type', type=str, required=False, nullable=True,
|
||||
choices=[mt.value for mt in ModelType], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_list = provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get('model_type'))
|
||||
model_provider_service = ModelProviderService()
|
||||
provider_list = model_provider_service.get_provider_list(
|
||||
tenant_id=tenant_id,
|
||||
model_type=args.get('model_type')
|
||||
)
|
||||
|
||||
return provider_list
|
||||
return jsonable_encoder({"data": provider_list})
|
||||
|
||||
|
||||
class ModelProviderCredentialApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
credentials = model_provider_service.get_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider
|
||||
)
|
||||
|
||||
return {
|
||||
"credentials": credentials
|
||||
}
|
||||
|
||||
|
||||
class ModelProviderValidateApi(Resource):
|
||||
@@ -36,21 +62,24 @@ class ModelProviderValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
def post(self, provider: str):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
result = True
|
||||
error = None
|
||||
|
||||
try:
|
||||
provider_service.custom_provider_config_validate(
|
||||
provider_name=provider_name,
|
||||
config=args['config']
|
||||
model_provider_service.provider_credentials_validate(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
credentials=args['credentials']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
@@ -64,26 +93,26 @@ class ModelProviderValidateApi(Resource):
|
||||
return response
|
||||
|
||||
|
||||
class ModelProviderUpdateApi(Resource):
|
||||
class ModelProviderApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
def post(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
provider_service.save_custom_provider_config(
|
||||
model_provider_service.save_provider_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
config=args['config']
|
||||
provider=provider,
|
||||
credentials=args['credentials']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
@@ -93,109 +122,36 @@ class ModelProviderUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_name: str):
|
||||
def delete(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_service.delete_custom_provider(
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_provider_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name
|
||||
provider=provider
|
||||
)
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class ModelProviderModelValidateApi(Resource):
|
||||
class ModelProviderIconApi(Resource):
|
||||
"""
|
||||
Get model provider icon
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
|
||||
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
|
||||
result = True
|
||||
error = None
|
||||
|
||||
try:
|
||||
provider_service.custom_provider_model_config_validate(
|
||||
provider_name=provider_name,
|
||||
model_name=args['model_name'],
|
||||
model_type=args['model_type'],
|
||||
config=args['config']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
error = str(ex)
|
||||
|
||||
response = {'result': 'success' if result else 'error'}
|
||||
|
||||
if not result:
|
||||
response['error'] = error
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class ModelProviderModelUpdateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
|
||||
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
|
||||
try:
|
||||
provider_service.add_or_save_custom_provider_model_config(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
model_name=args['model_name'],
|
||||
model_type=args['model_type'],
|
||||
config=args['config']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_name: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_service.delete_custom_provider_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
model_name=args['model_name'],
|
||||
model_type=args['model_type']
|
||||
def get(self, provider: str, icon_type: str, lang: str):
|
||||
model_provider_service = ModelProviderService()
|
||||
icon, mimetype = model_provider_service.get_model_provider_icon(
|
||||
provider=provider,
|
||||
icon_type=icon_type,
|
||||
lang=lang
|
||||
)
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
return send_file(io.BytesIO(icon), mimetype=mimetype)
|
||||
|
||||
|
||||
class PreferredProviderTypeUpdateApi(Resource):
|
||||
@@ -203,71 +159,36 @@ class PreferredProviderTypeUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
def post(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
|
||||
choices=['system', 'custom'], location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_service.switch_preferred_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.switch_preferred_provider(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
preferred_provider_type=args['preferred_provider_type']
|
||||
)
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class ModelProviderModelParameterRuleApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
|
||||
try:
|
||||
parameter_rules = provider_service.get_model_parameter_rules(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=provider_name,
|
||||
model_name=args['model_name'],
|
||||
model_type='text-generation'
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"Current Text Generation Model is invalid. Please switch to the available model.")
|
||||
|
||||
rules = {
|
||||
k: {
|
||||
'enabled': v.enabled,
|
||||
'min': v.min,
|
||||
'max': v.max,
|
||||
'default': v.default,
|
||||
'precision': v.precision
|
||||
}
|
||||
for k, v in vars(parameter_rules).items()
|
||||
}
|
||||
|
||||
return rules
|
||||
|
||||
|
||||
class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
if provider_name != 'anthropic':
|
||||
raise ValueError(f'provider name {provider_name} is invalid')
|
||||
def get(self, provider: str):
|
||||
if provider != 'anthropic':
|
||||
raise ValueError(f'provider name {provider} is invalid')
|
||||
|
||||
data = BillingService.get_model_provider_payment_link(provider_name=provider_name,
|
||||
data = BillingService.get_model_provider_payment_link(provider_name=provider,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
account_id=current_user.id)
|
||||
return data
|
||||
@@ -277,11 +198,11 @@ class ModelProviderFreeQuotaSubmitApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
provider_service = ProviderService()
|
||||
result = provider_service.free_quota_submit(
|
||||
def post(self, provider: str):
|
||||
model_provider_service = ModelProviderService()
|
||||
result = model_provider_service.free_quota_submit(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name
|
||||
provider=provider
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -291,15 +212,15 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
def get(self, provider: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('token', type=str, required=False, nullable=True, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
result = provider_service.free_quota_qualification_verify(
|
||||
model_provider_service = ModelProviderService()
|
||||
result = model_provider_service.free_quota_qualification_verify(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
provider=provider,
|
||||
token=args['token']
|
||||
)
|
||||
|
||||
@@ -307,19 +228,18 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
|
||||
|
||||
|
||||
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
|
||||
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
|
||||
api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
|
||||
api.add_resource(ModelProviderModelValidateApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/models/validate')
|
||||
api.add_resource(ModelProviderModelUpdateApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/models')
|
||||
|
||||
api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers/<string:provider>/credentials')
|
||||
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider>/credentials/validate')
|
||||
api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/<string:provider>')
|
||||
api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers/<string:provider>/'
|
||||
'<string:icon_type>/<string:lang>')
|
||||
|
||||
api.add_resource(PreferredProviderTypeUpdateApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type')
|
||||
api.add_resource(ModelProviderModelParameterRuleApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
|
||||
'/workspaces/current/model-providers/<string:provider>/preferred-provider-type')
|
||||
api.add_resource(ModelProviderPaymentCheckoutUrlApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/checkout-url')
|
||||
'/workspaces/current/model-providers/<string:provider>/checkout-url')
|
||||
api.add_resource(ModelProviderFreeQuotaSubmitApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')
|
||||
'/workspaces/current/model-providers/<string:provider>/free-quota-submit')
|
||||
api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/free-quota-qualification-verify')
|
||||
'/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify')
|
||||
|
Reference in New Issue
Block a user