Feat/model as tool (#2744)

This commit is contained in:
Yeuoly
2024-03-08 15:22:55 +08:00
committed by GitHub
parent 3231a8c51c
commit 40c646cf7a
26 changed files with 840 additions and 43 deletions

View File

@@ -22,6 +22,7 @@ from core.tools.utils.encoder import serialize_base_model_array, serialize_base_
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider
from services.model_provider_service import ModelProviderService
class ToolManageService:
@@ -50,11 +51,13 @@ class ToolManageService:
:param provider: the provider dict
"""
url_prefix = (current_app.config.get("CONSOLE_API_URL")
+ "/console/api/workspaces/current/tool-provider/builtin/")
+ "/console/api/workspaces/current/tool-provider/")
if 'icon' in provider:
if provider['type'] == UserToolProvider.ProviderType.BUILTIN.value:
provider['icon'] = url_prefix + provider['name'] + '/icon'
provider['icon'] = url_prefix + 'builtin/' + provider['name'] + '/icon'
elif provider['type'] == UserToolProvider.ProviderType.MODEL.value:
provider['icon'] = url_prefix + 'model/' + provider['name'] + '/icon'
elif provider['type'] == UserToolProvider.ProviderType.API.value:
try:
provider['icon'] = json.loads(provider['icon'])
@@ -505,6 +508,46 @@ class ToolManageService:
return icon_bytes, mime_type
@staticmethod
def get_model_tool_provider_icon(
provider: str
):
"""
get tool provider icon and it's mimetype
"""
service = ModelProviderService()
icon_bytes, mime_type = service.get_model_provider_icon(provider=provider, icon_type='icon_small', lang='en_US')
if icon_bytes is None:
raise ValueError(f'provider {provider} does not exists')
return icon_bytes, mime_type
@staticmethod
def list_model_tool_provider_tools(
user_id: str, tenant_id: str, provider: str
):
"""
list model tool provider tools
"""
provider_controller = ToolManager.get_model_provider(tenant_id=tenant_id, provider_name=provider)
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
result = [
UserTool(
author=tool.identity.author,
name=tool.identity.name,
label=tool.identity.label,
description=tool.description.human,
parameters=tool.parameters or []
) for tool in tools
]
return json.loads(
serialize_base_model_array(result)
)
@staticmethod
def delete_api_tool_provider(
user_id: str, tenant_id: str, provider_name: str