Refactor/react agent (#3355)

This commit is contained in:
Yeuoly
2024-04-11 18:34:17 +08:00
committed by GitHub
parent 509c640a80
commit cea107b165
9 changed files with 589 additions and 511 deletions

View File

@@ -238,6 +238,34 @@ class BaseAgentRunner(AppRunner):
return prompt_tool
def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
"""
Init tools
"""
tool_instances = {}
prompt_messages_tools = []
for tool in self.app_config.agent.tools if self.app_config.agent else []:
try:
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
except Exception:
# api tool may be deleted
continue
# save tool entity
tool_instances[tool.tool_name] = tool_entity
# save prompt tool
prompt_messages_tools.append(prompt_tool)
# convert dataset tools into ModelRuntime Tool format
for dataset_tool in self.dataset_tools:
prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
# save prompt tool
prompt_messages_tools.append(prompt_tool)
# save tool entity
tool_instances[dataset_tool.identity.name] = dataset_tool
return tool_instances, prompt_messages_tools
def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
"""
update prompt message tool
@@ -325,7 +353,7 @@ class BaseAgentRunner(AppRunner):
tool_name: str,
tool_input: Union[str, dict],
thought: str,
observation: Union[str, str],
observation: Union[str, dict],
tool_invoke_meta: Union[str, dict],
answer: str,
messages_ids: list[str],