feat: add tool labels (#2178)
This commit is contained in:
@@ -31,6 +31,7 @@ import mimetypes
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_builtin_providers = {}
|
||||
_builtin_tools_labels = {}
|
||||
|
||||
class ToolManager:
|
||||
@staticmethod
|
||||
@@ -233,7 +234,7 @@ class ToolManager:
|
||||
if len(_builtin_providers) > 0:
|
||||
return list(_builtin_providers.values())
|
||||
|
||||
builtin_providers = []
|
||||
builtin_providers: List[BuiltinToolProviderController] = []
|
||||
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
|
||||
if provider.startswith('__'):
|
||||
continue
|
||||
@@ -264,8 +265,30 @@ class ToolManager:
|
||||
# cache the builtin providers
|
||||
for provider in builtin_providers:
|
||||
_builtin_providers[provider.identity.name] = provider
|
||||
for tool in provider.get_tools():
|
||||
_builtin_tools_labels[tool.identity.name] = tool.identity.label
|
||||
|
||||
return builtin_providers
|
||||
|
||||
@staticmethod
|
||||
def get_tool_label(tool_name: str) -> Union[I18nObject, None]:
|
||||
"""
|
||||
get the tool label
|
||||
|
||||
:param tool_name: the name of the tool
|
||||
|
||||
:return: the label of the tool
|
||||
"""
|
||||
global _builtin_tools_labels
|
||||
if len(_builtin_tools_labels) == 0:
|
||||
# init the builtin providers
|
||||
ToolManager.list_builtin_providers()
|
||||
|
||||
if tool_name not in _builtin_tools_labels:
|
||||
return None
|
||||
|
||||
return _builtin_tools_labels[tool_name]
|
||||
|
||||
@staticmethod
|
||||
def user_list_providers(
|
||||
user_id: str,
|
||||
|
Reference in New Issue
Block a user