feat(api): Add image multimodal support for LLMNode (#17372)
Enhance `LLMNode` with multimodal capability, introducing support for image outputs. This implementation extracts base64-encoded images from LLM responses, saves them to the storage service, and records the file metadata in the `ToolFile` table. In conversations, these images are rendered as markdown-based inline images. Additionally, the images are included in the LLMNode's output as file variables, enabling subsequent nodes in the workflow to utilize them. To integrate file outputs into workflows, adjustments to the frontend code are necessary. For multimodal output functionality, updates to related model configurations are required. Currently, this capability has been applied exclusively to Google's Gemini models. Close #15814. Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
192
api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py
Normal file
192
api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import uuid
|
||||
from typing import NamedTuple
|
||||
from unittest import mock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from core.file import FileTransferMethod, FileType, models
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools import signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.nodes.llm.file_saver import (
|
||||
FileSaverImpl,
|
||||
_extract_content_type_and_extension,
|
||||
_get_extension,
|
||||
_validate_extension_override,
|
||||
)
|
||||
from models import ToolFile
|
||||
|
||||
_PNG_DATA = b"\x89PNG\r\n\x1a\n"
|
||||
|
||||
|
||||
def _gen_id():
    """Return a freshly generated random UUID (version 4) as a string."""
    fresh_uuid = uuid.uuid4()
    return str(fresh_uuid)
|
||||
|
||||
|
||||
class TestFileSaverImpl:
    """Tests for ``FileSaverImpl`` with the tool-file manager, URL signing and HTTP access mocked."""

    def test_save_binary_string(self, monkeypatch):
        """Raw bytes are handed to the tool-file manager and come back as a tool-file ``File``."""
        mime_type = "image/png"
        file_type = FileType.IMAGE
        mock_signed_url = "https://example.com/image.png"
        tenant_id = _gen_id()
        user_id = _gen_id()

        mock_tool_file = ToolFile(
            id=_gen_id(),
            user_id=user_id,
            tenant_id=tenant_id,
            conversation_id=None,
            file_key="test-file-key",
            mimetype=mime_type,
            original_url=None,
            name=f"{_gen_id()}.png",
            size=len(_PNG_DATA),
        )

        # Replace the tool-file manager so no real storage is touched.
        mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
        mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
        monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)

        # `File.generate_url` is resolved through `sign_tool_file` as seen from the
        # `core.file.models` module, so the patch is applied on that module.
        mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file, return_value=mock_signed_url)
        monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)

        mocked_engine = mock.MagicMock(spec=Engine)
        storage_file_manager = FileSaverImpl(
            user_id=user_id,
            tenant_id=tenant_id,
            engine_factory=mocked_engine,
        )

        file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)

        assert file.tenant_id == tenant_id
        assert file.type == file_type
        assert file.transfer_method == FileTransferMethod.TOOL_FILE
        assert file.extension == ".png"
        assert file.mime_type == mime_type
        assert file.size == len(_PNG_DATA)
        assert file.related_id == mock_tool_file.id

        assert file.generate_url() == mock_signed_url

        expected_create_kwargs = {
            "user_id": user_id,
            "tenant_id": tenant_id,
            "conversation_id": None,
            "file_binary": _PNG_DATA,
            "mimetype": mime_type,
        }
        mocked_tool_file_manager.create_file_by_raw.assert_called_once_with(**expected_create_kwargs)
        mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")

    def test_save_remote_url_request_failed(self, monkeypatch):
        """A 401 response from the remote URL surfaces as ``httpx.HTTPStatusError``."""
        _TEST_URL = "https://example.com/image.png"
        mock_response = httpx.Response(
            status_code=401,
            request=httpx.Request("GET", _TEST_URL),
        )
        mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
        monkeypatch.setattr(ssrf_proxy, "get", mock_get)
        file_saver = FileSaverImpl(
            user_id=_gen_id(),
            tenant_id=_gen_id(),
        )

        with pytest.raises(httpx.HTTPStatusError) as exc:
            file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)

        mock_get.assert_called_once_with(_TEST_URL)
        assert exc.value.response.status_code == 401

    def test_save_remote_url_success(self, monkeypatch):
        """A 200 response is forwarded to ``save_binary_string`` with the derived extension."""
        _TEST_URL = "https://example.com/image.png"
        mime_type = "image/png"
        user_id = _gen_id()
        tenant_id = _gen_id()

        mock_response = httpx.Response(
            status_code=200,
            content=b"test-data",
            headers={"Content-Type": mime_type},
            request=httpx.Request("GET", _TEST_URL),
        )
        mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
        monkeypatch.setattr(ssrf_proxy, "get", mock_get)

        file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
        mock_tool_file = ToolFile(
            id=_gen_id(),
            user_id=user_id,
            tenant_id=tenant_id,
            conversation_id=None,
            file_key="test-file-key",
            mimetype=mime_type,
            original_url=None,
            name=f"{_gen_id()}.png",
            size=len(_PNG_DATA),
        )
        # Stub the instance's own `save_binary_string` so only the download path is exercised.
        mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file)
        monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string)

        file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)

        mock_save_binary_string.assert_called_once_with(
            mock_response.content,
            mime_type,
            FileType.IMAGE,
            extension_override=".png",
        )
        assert file == mock_tool_file
|
||||
|
||||
|
||||
def test_validate_extension_override():
    """Check which extension overrides ``_validate_extension_override`` accepts.

    ``None``, the empty string and dot-prefixed extensions must be returned
    unchanged; extensions without a leading dot must raise ``ValueError``.
    """
    # Accepted values are expected to round-trip unchanged.
    for valid_ext_override in [None, "", ".png", ".tar.gz"]:
        assert valid_ext_override == _validate_extension_override(valid_ext_override)

    # A missing leading dot is rejected.
    for invalid_ext_override in ["png", "tar.gz"]:
        with pytest.raises(ValueError):
            _validate_extension_override(invalid_ext_override)
|
||||
|
||||
|
||||
class TestExtractContentTypeAndExtension:
    """Tests for ``_extract_content_type_and_extension(url, content_type)``."""

    def test_with_both_content_type_and_extension(self):
        """The supplied content type wins over a conflicting URL extension."""
        content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png")
        assert content_type == "image/png"
        assert extension == ".png"

    def test_url_with_file_extension(self):
        """With no usable content type, both values are derived from the URL's extension."""
        # Distinct names for input and result: the loop variable is not rebound.
        for given_content_type in [None, ""]:
            content_type, extension = _extract_content_type_and_extension(
                "https://example.com/image.png", given_content_type
            )
            assert content_type == "image/png"
            assert extension == ".png"

    def test_response_with_content_type(self):
        """With no URL extension, the extension is derived from the content type."""
        content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png")
        assert content_type == "image/png"
        assert extension == ".png"

    def test_no_content_type_and_no_extension(self):
        """With neither source available, the generic binary type and ``.bin`` are returned."""
        for given_content_type in [None, ""]:
            content_type, extension = _extract_content_type_and_extension(
                "https://example.com/image", given_content_type
            )
            assert content_type == "application/octet-stream"
            assert extension == ".bin"
|
||||
|
||||
|
||||
class TestGetExtension:
    """Tests for ``_get_extension(mime_type, extension_override)``."""

    def test_with_extension_override(self):
        """An explicit override, including the empty string, is returned as-is."""
        mime_type = "image/png"
        overrides = (".jpg", "")
        for override in overrides:
            assert _get_extension(mime_type, override) == override

    def test_without_extension_override(self):
        """Without an override the extension is derived from the MIME type."""
        resolved = _get_extension("image/png")
        assert resolved == ".png"
|
@@ -1,5 +1,8 @@
|
||||
import base64
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -30,6 +33,7 @@ from core.workflow.nodes.llm.entities import (
|
||||
VisionConfig,
|
||||
VisionConfigOptions,
|
||||
)
|
||||
from core.workflow.nodes.llm.file_saver import LLMFileSaver
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from models.enums import UserFrom
|
||||
from models.provider import ProviderType
|
||||
@@ -49,8 +53,8 @@ class MockTokenBufferMemory:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_node():
|
||||
data = LLMNodeData(
|
||||
def llm_node_data() -> LLMNodeData:
|
||||
return LLMNodeData(
|
||||
title="Test LLM",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||
prompt_template=[],
|
||||
@@ -64,42 +68,65 @@ def llm_node():
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def graph_init_params() -> GraphInitParams:
    """Provide minimal ``GraphInitParams``: every id is "1", called by an account via the service API."""
    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        graph_config={},
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.SERVICE_API,
        call_depth=0,
    )
    return init_params
|
||||
|
||||
|
||||
@pytest.fixture
def graph() -> Graph:
    """Provide a ``Graph`` rooted at node "1" with empty answer and end stream routing."""
    answer_routes = AnswerStreamGenerateRoute(
        answer_dependencies={},
        answer_generate_route={},
    )
    end_param = EndStreamParam(
        end_dependencies={},
        end_stream_variable_selector_mapping={},
    )
    return Graph(
        root_node_id="1",
        answer_stream_generate_routes=answer_routes,
        end_stream_param=end_param,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def graph_runtime_state() -> GraphRuntimeState:
    """Provide a ``GraphRuntimeState`` starting at 0 with an empty variable pool."""
    return GraphRuntimeState(
        variable_pool=VariablePool(system_variables={}, user_inputs={}),
        start_at=0,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def llm_node(
    llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
) -> LLMNode:
    """Build an ``LLMNode`` from the shared fixtures, with a mocked file saver.

    The node data, init params, graph and runtime state all come from the
    sibling fixtures; only the ``LLMFileSaver`` is replaced by a
    spec-constrained mock so that no file is actually persisted.
    """
    mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
    node = LLMNode(
        id="1",
        config={
            "id": "1",
            "data": llm_node_data.model_dump(),
        },
        graph_init_params=graph_init_params,
        graph=graph,
        graph_runtime_state=graph_runtime_state,
        llm_file_saver=mock_file_saver,
    )
    return node
|
||||
|
||||
@@ -465,3 +492,167 @@ def test_handle_list_messages_basic(llm_node):
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], UserPromptMessage)
|
||||
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
|
||||
|
||||
|
||||
@pytest.fixture
def llm_node_for_multimodal(
    llm_node_data, graph_init_params, graph, graph_runtime_state
) -> tuple[LLMNode, LLMFileSaver]:
    """Return an ``LLMNode`` together with the mocked ``LLMFileSaver`` injected into it.

    Exposing the mock lets tests assert on the save calls issued by the node.
    """
    mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
    node_config = {
        "id": "1",
        "data": llm_node_data.model_dump(),
    }
    node = LLMNode(
        id="1",
        config=node_config,
        graph_init_params=graph_init_params,
        graph=graph,
        graph_runtime_state=graph_runtime_state,
        llm_file_saver=mock_file_saver,
    )
    return node, mock_file_saver
|
||||
|
||||
|
||||
class TestLLMNodeSaveMultiModalImageOutput:
    """Tests for ``LLMNode._save_multimodal_image_output`` using a mocked file saver."""

    def test_llm_node_save_inline_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]):
        """Base64 image content is decoded and stored via ``save_binary_string``."""
        llm_node, mock_file_saver = llm_node_for_multimodal
        saved_file = File(
            id=str(uuid.uuid4()),
            tenant_id="1",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.TOOL_FILE,
            related_id=str(uuid.uuid4()),
            filename="test-file.png",
            extension=".png",
            mime_type="image/png",
            size=9,
        )
        mock_file_saver.save_binary_string.return_value = saved_file
        encoded_payload = base64.b64encode(b"test-data").decode()
        content = ImagePromptMessageContent(
            format="png",
            base64_data=encoded_payload,
            mime_type="image/png",
        )

        file = llm_node._save_multimodal_image_output(content=content)

        assert llm_node._file_outputs == [saved_file]
        assert file == saved_file
        mock_file_saver.save_binary_string.assert_called_once_with(
            data=b"test-data", mime_type="image/png", file_type=FileType.IMAGE
        )

    def test_llm_node_save_url_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]):
        """URL image content is stored via ``save_remote_url`` with the URL and image file type."""
        llm_node, mock_file_saver = llm_node_for_multimodal
        saved_file = File(
            id=str(uuid.uuid4()),
            tenant_id="1",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.TOOL_FILE,
            related_id=str(uuid.uuid4()),
            filename="test-file.png",
            extension=".png",
            mime_type="image/png",
            size=9,
        )
        mock_file_saver.save_remote_url.return_value = saved_file
        content = ImagePromptMessageContent(
            format="png",
            url="https://example.com/image.png",
            mime_type="image/jpg",
        )

        file = llm_node._save_multimodal_image_output(content=content)

        assert llm_node._file_outputs == [saved_file]
        assert file == saved_file
        mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)
|
||||
|
||||
|
||||
def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
    """``_image_file_to_markdown`` renders a file as a markdown inline image of its URL."""
    mock_file = mock.MagicMock(spec=File)
    mock_file.generate_url.return_value = "https://example.com/image.png"
    markdown = llm_node._image_file_to_markdown(mock_file)
    # Markdown image syntax with empty alt text, wrapping the generated URL.
    assert markdown == "![](https://example.com/image.png)"
|
||||
|
||||
|
||||
class TestSaveMultimodalOutputAndConvertResultToMarkdown:
    """Tests for ``LLMNode._save_multimodal_output_and_convert_result_to_markdown``.

    Each test consumes the returned generator, checks the yielded strings, and
    checks which file-saver methods were (or were not) called.
    """

    def test_str_content(self, llm_node_for_multimodal):
        """A plain string is yielded unchanged and nothing is saved."""
        llm_node, mock_file_saver = llm_node_for_multimodal
        gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world")
        assert list(gen) == ["hello world"]
        mock_file_saver.save_binary_string.assert_not_called()
        mock_file_saver.save_remote_url.assert_not_called()

    def test_text_prompt_message_content(self, llm_node_for_multimodal):
        """Text content items yield their text and nothing is saved."""
        llm_node, mock_file_saver = llm_node_for_multimodal
        gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
            [TextPromptMessageContent(data="hello world")]
        )
        assert list(gen) == ["hello world"]
        mock_file_saver.save_binary_string.assert_not_called()
        mock_file_saver.save_remote_url.assert_not_called()

    def test_image_content_with_inline_data(self, llm_node_for_multimodal, monkeypatch):
        """Inline base64 image content is saved and yielded as a markdown image."""
        llm_node, mock_file_saver = llm_node_for_multimodal

        image_raw_data = b"PNG_DATA"
        image_b64_data = base64.b64encode(image_raw_data).decode()

        mock_saved_file = File(
            id=str(uuid.uuid4()),
            tenant_id="1",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.TOOL_FILE,
            filename="test.png",
            extension=".png",
            size=len(image_raw_data),
            related_id=str(uuid.uuid4()),
            url="https://example.com/test.png",
            storage_key="test_storage_key",
        )
        mock_file_saver.save_binary_string.return_value = mock_saved_file
        gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
            [
                ImagePromptMessageContent(
                    format="png",
                    base64_data=image_b64_data,
                    mime_type="image/png",
                )
            ]
        )
        yielded_strs = list(gen)
        assert len(yielded_strs) == 1

        # This assertion requires careful handling.
        # `FILES_URL` settings can vary across environments, which might lead to fragile tests.
        #
        # Rather than asserting the complete URL returned by
        # _save_multimodal_output_and_convert_result_to_markdown, we verify that the result
        # includes the markdown image syntax and the expected file URL path.
        expected_file_url_path = f"/files/tools/{mock_saved_file.related_id}.png"
        assert yielded_strs[0].startswith("![](")
        assert expected_file_url_path in yielded_strs[0]
        assert yielded_strs[0].endswith(")")
        mock_file_saver.save_binary_string.assert_called_once_with(
            data=image_raw_data,
            mime_type="image/png",
            file_type=FileType.IMAGE,
        )
        assert mock_saved_file in llm_node._file_outputs

    def test_unknown_content_type(self, llm_node_for_multimodal):
        """Content of an unsupported type is yielded as its ``str()`` form; nothing is saved."""
        llm_node, mock_file_saver = llm_node_for_multimodal
        gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"]))
        assert list(gen) == ["frozenset({'hello world'})"]
        mock_file_saver.save_binary_string.assert_not_called()
        mock_file_saver.save_remote_url.assert_not_called()

    def test_unknown_item_type(self, llm_node_for_multimodal):
        """A list item of an unsupported type is yielded as its ``str()`` form; nothing is saved."""
        llm_node, mock_file_saver = llm_node_for_multimodal
        gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])])
        assert list(gen) == ["frozenset({'hello world'})"]
        mock_file_saver.save_binary_string.assert_not_called()
        mock_file_saver.save_remote_url.assert_not_called()

    def test_none_content(self, llm_node_for_multimodal):
        """``None`` content yields nothing and nothing is saved."""
        llm_node, mock_file_saver = llm_node_for_multimodal
        gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None)
        assert list(gen) == []
        mock_file_saver.save_binary_string.assert_not_called()
        mock_file_saver.save_remote_url.assert_not_called()
|
||||
|
Reference in New Issue
Block a user