feat: add tool labels (#2178)

This commit is contained in:
Yeuoly
2024-01-24 20:14:45 +08:00
committed by GitHub
parent 0940084fd2
commit 7cb75cb2e7
8 changed files with 90 additions and 3 deletions

View File

@@ -18,6 +18,7 @@ from core.model_runtime.entities.message_entities import (AssistantPromptMessage
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool_manager import ToolManager
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
@@ -281,7 +282,7 @@ class GenerateTaskPipeline:
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, QueueAgentThoughtEvent):
agent_thought = (
agent_thought: MessageAgentThought = (
db.session.query(MessageAgentThought)
.filter(MessageAgentThought.id == event.agent_thought_id)
.first()
@@ -298,6 +299,7 @@ class GenerateTaskPipeline:
'thought': agent_thought.thought,
'observation': agent_thought.observation,
'tool': agent_thought.tool,
'tool_labels': agent_thought.tool_labels,
'tool_input': agent_thought.tool_input,
'created_at': int(self._message.created_at.timestamp()),
'message_files': agent_thought.files

View File

@@ -396,6 +396,7 @@ class BaseAssistantApplicationRunner(AppRunner):
message_chain_id=None,
thought='',
tool=tool_name,
tool_labels_str='{}',
tool_input=tool_input,
message=message,
message_token=0,
@@ -469,6 +470,21 @@ class BaseAssistantApplicationRunner(AppRunner):
agent_thought.tokens = llm_usage.total_tokens
agent_thought.total_price = llm_usage.total_price
# check if tool labels is not empty
labels = agent_thought.tool_labels or {}
tools = agent_thought.tool.split(';') if agent_thought.tool else []
for tool in tools:
if not tool:
continue
if tool not in labels:
tool_label = ToolManager.get_tool_label(tool)
if tool_label:
labels[tool] = tool_label.to_dict()
else:
labels[tool] = {'en_US': tool, 'zh_Hans': tool}
agent_thought.tool_labels_str = json.dumps(labels)
db.session.commit()
def get_history_prompt_messages(self) -> List[PromptMessage]:

View File

@@ -31,6 +31,7 @@ import mimetypes
logger = logging.getLogger(__name__)
_builtin_providers = {}
_builtin_tools_labels = {}
class ToolManager:
@staticmethod
@@ -233,7 +234,7 @@ class ToolManager:
if len(_builtin_providers) > 0:
return list(_builtin_providers.values())
builtin_providers = []
builtin_providers: List[BuiltinToolProviderController] = []
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
if provider.startswith('__'):
continue
@@ -264,8 +265,30 @@ class ToolManager:
# cache the builtin providers
for provider in builtin_providers:
_builtin_providers[provider.identity.name] = provider
for tool in provider.get_tools():
_builtin_tools_labels[tool.identity.name] = tool.identity.label
return builtin_providers
@staticmethod
def get_tool_label(tool_name: str) -> Union[I18nObject, None]:
"""
get the tool label
:param tool_name: the name of the tool
:return: the label of the tool
"""
global _builtin_tools_labels
if len(_builtin_tools_labels) == 0:
# init the builtin providers
ToolManager.list_builtin_providers()
if tool_name not in _builtin_tools_labels:
return None
return _builtin_tools_labels[tool_name]
@staticmethod
def user_list_providers(
user_id: str,