Feat/Agent-Image-Processing (#3293)

Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
Yeuoly
2024-04-10 14:48:40 +08:00
committed by GitHub
parent 240c793e7a
commit 14bb0b02ac
6 changed files with 148 additions and 40 deletions

View File

@@ -1,6 +1,7 @@
import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Union
from core.agent.base_agent_runner import BaseAgentRunner
@@ -10,20 +11,21 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Conversation, Message, MessageAgentThought
from models.model import Message
logger = logging.getLogger(__name__)
class FunctionCallAgentRunner(BaseAgentRunner):
def run(self, conversation: Conversation,
message: Message,
def run(self, message: Message,
query: str,
) -> Generator[LLMResultChunk, None, None]:
"""
@@ -35,11 +37,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
prompt_template = app_config.prompt_template.simple_prompt_template or ''
prompt_messages = self.history_prompt_messages
prompt_messages = self.organize_prompt_messages(
prompt_template=prompt_template,
query=query,
prompt_messages=prompt_messages
)
prompt_messages = self._init_system_message(prompt_template, prompt_messages)
prompt_messages = self._organize_user_query(query, prompt_messages)
# convert tools into ModelRuntime Tool format
prompt_messages_tools: list[PromptMessageTool] = []
@@ -68,7 +67,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# continue to run until there is not any tool call
function_call_state = True
agent_thoughts: list[MessageAgentThought] = []
llm_usage = {
'usage': None
}
@@ -287,9 +285,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
}
tool_responses.append(tool_response)
prompt_messages = self.organize_prompt_messages(
prompt_template=prompt_template,
query=None,
prompt_messages = self._organize_assistant_message(
tool_call_id=tool_call_id,
tool_call_name=tool_call_name,
tool_response=tool_response['tool_response'],
@@ -324,6 +320,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
iteration_step += 1
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
@@ -386,29 +384,68 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return tool_calls
def organize_prompt_messages(self, prompt_template: str,
query: str = None,
tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
prompt_messages: list[PromptMessage] = None
) -> list[PromptMessage]:
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize prompt messages
Initialize system message
"""
if not prompt_messages:
prompt_messages = [
if not prompt_messages and prompt_template:
return [
SystemPromptMessage(content=prompt_template),
UserPromptMessage(content=query),
]
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
prompt_message_contents = [TextPromptMessageContent(data=query)]
for file_obj in self.files:
prompt_message_contents.append(file_obj.prompt_message_content)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
if tool_response:
prompt_messages = prompt_messages.copy()
prompt_messages.append(
ToolPromptMessage(
content=tool_response,
tool_call_id=tool_call_id,
name=tool_call_name,
)
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize assistant message
"""
prompt_messages = deepcopy(prompt_messages)
if tool_response is not None:
prompt_messages.append(
ToolPromptMessage(
content=tool_response,
tool_call_id=tool_call_id,
name=tool_call_name,
)
)
return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
As for now, gpt supports both fc and vision at the first iteration.
We need to remove the image messages from the prompt messages at the first iteration.
"""
prompt_messages = deepcopy(prompt_messages)
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = '\n'.join([
content.data if content.type == PromptMessageContentType.TEXT else
'[image]' if content.type == PromptMessageContentType.IMAGE else
'[file]'
for content in prompt_message.content
])
return prompt_messages