Feat/chat message image first for agent and advanced_chat APP (#23796)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -512,7 +512,6 @@ class BaseAgentRunner(AppRunner):
|
||||
if not file_objs:
|
||||
return UserPromptMessage(content=message.query)
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||
for file in file_objs:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
@@ -520,4 +519,6 @@ class BaseAgentRunner(AppRunner):
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||
|
||||
return UserPromptMessage(content=prompt_message_contents)
|
||||
|
@@ -39,9 +39,6 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
# get image detail config
|
||||
image_detail_config = (
|
||||
self.application_generate_entity.file_upload_config.image_config.detail
|
||||
@@ -52,6 +49,8 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in self.files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
@@ -59,6 +58,7 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
|
@@ -395,9 +395,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
# get image detail config
|
||||
image_detail_config = (
|
||||
self.application_generate_entity.file_upload_config.image_config.detail
|
||||
@@ -408,6 +405,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in self.files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
@@ -415,6 +414,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
|
@@ -125,11 +125,11 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
|
||||
if files:
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
|
||||
for file in files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
@@ -196,16 +196,17 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
|
||||
query = parser.format(prompt_inputs)
|
||||
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
if memory and memory_config:
|
||||
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
|
||||
|
||||
if files and query is not None:
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
for file in files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
@@ -215,27 +216,27 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
last_message = prompt_messages[-1] if prompt_messages else None
|
||||
if last_message and last_message.role == PromptMessageRole.USER:
|
||||
# get last user message content and add files
|
||||
prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))]
|
||||
for file in files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=cast(str, last_message.content)))
|
||||
|
||||
last_message.content = prompt_message_contents
|
||||
else:
|
||||
prompt_message_contents = [TextPromptMessageContent(data="")] # not for query
|
||||
for file in files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=""))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
||||
for file in files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
elif query:
|
||||
|
@@ -265,11 +265,11 @@ class SimplePromptTransform(PromptTransform):
|
||||
) -> UserPromptMessage:
|
||||
if files:
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
|
||||
for file in files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
|
||||
|
||||
prompt_message = UserPromptMessage(content=prompt_message_contents)
|
||||
else:
|
||||
|
@@ -164,7 +164,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
)
|
||||
assert isinstance(prompt_messages[3].content, list)
|
||||
assert len(prompt_messages[3].content) == 2
|
||||
assert prompt_messages[3].content[1].data == files[0].remote_url
|
||||
assert prompt_messages[3].content[0].data == files[0].remote_url
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
Reference in New Issue
Block a user