feat(api): Add image multimodal support for LLMNode (#17372)

Enhance `LLMNode` with multimodal capability, introducing support for
image outputs.

This implementation extracts base64-encoded images from LLM responses,
saves them to the storage service, and records the file metadata in the
`ToolFile` table. In conversations, these images are rendered as
markdown-based inline images.
Additionally, the images are included in the LLMNode's output as
file variables, enabling subsequent nodes in the workflow to utilize them.

To integrate file outputs into workflows, adjustments to the frontend code
are necessary.

For multimodal output functionality, updates to related model configurations
are required. Currently, this capability has been applied exclusively to
Google's Gemini models.

Close #15814.

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
QuantumGhost
2025-04-30 17:28:02 +08:00
committed by GitHub
parent 6c9a9d344a
commit 349c3cf7b8
24 changed files with 971 additions and 191 deletions

View File

@@ -0,0 +1,192 @@
import uuid
from typing import NamedTuple
from unittest import mock
import httpx
import pytest
from sqlalchemy import Engine
from core.file import FileTransferMethod, FileType, models
from core.helper import ssrf_proxy
from core.tools import signature
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.nodes.llm.file_saver import (
FileSaverImpl,
_extract_content_type_and_extension,
_get_extension,
_validate_extension_override,
)
from models import ToolFile
_PNG_DATA = b"\x89PNG\r\n\x1a\n"
def _gen_id():
return str(uuid.uuid4())
class TestFileSaverImpl:
def test_save_binary_string(self, monkeypatch):
user_id = _gen_id()
tenant_id = _gen_id()
file_type = FileType.IMAGE
mime_type = "image/png"
mock_signed_url = "https://example.com/image.png"
mock_tool_file = ToolFile(
id=_gen_id(),
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_key="test-file-key",
mimetype=mime_type,
original_url=None,
name=f"{_gen_id()}.png",
size=len(_PNG_DATA),
)
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
mocked_engine = mock.MagicMock(spec=Engine)
mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
# Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here.
mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file)
# Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
mocked_sign_file.return_value = mock_signed_url
storage_file_manager = FileSaverImpl(
user_id=user_id,
tenant_id=tenant_id,
engine_factory=mocked_engine,
)
file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
assert file.tenant_id == tenant_id
assert file.type == file_type
assert file.transfer_method == FileTransferMethod.TOOL_FILE
assert file.extension == ".png"
assert file.mime_type == mime_type
assert file.size == len(_PNG_DATA)
assert file.related_id == mock_tool_file.id
assert file.generate_url() == mock_signed_url
mocked_tool_file_manager.create_file_by_raw.assert_called_once_with(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_binary=_PNG_DATA,
mimetype=mime_type,
)
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
def test_save_remote_url_request_failed(self, monkeypatch):
_TEST_URL = "https://example.com/image.png"
mock_request = httpx.Request("GET", _TEST_URL)
mock_response = httpx.Response(
status_code=401,
request=mock_request,
)
file_saver = FileSaverImpl(
user_id=_gen_id(),
tenant_id=_gen_id(),
)
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
with pytest.raises(httpx.HTTPStatusError) as exc:
file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
mock_get.assert_called_once_with(_TEST_URL)
assert exc.value.response.status_code == 401
def test_save_remote_url_success(self, monkeypatch):
_TEST_URL = "https://example.com/image.png"
mime_type = "image/png"
user_id = _gen_id()
tenant_id = _gen_id()
mock_request = httpx.Request("GET", _TEST_URL)
mock_response = httpx.Response(
status_code=200,
content=b"test-data",
headers={"Content-Type": mime_type},
request=mock_request,
)
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
mock_tool_file = ToolFile(
id=_gen_id(),
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_key="test-file-key",
mimetype=mime_type,
original_url=None,
name=f"{_gen_id()}.png",
size=len(_PNG_DATA),
)
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file)
monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string)
file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
mock_save_binary_string.assert_called_once_with(
mock_response.content,
mime_type,
FileType.IMAGE,
extension_override=".png",
)
assert file == mock_tool_file
def test_validate_extension_override():
class TestCase(NamedTuple):
extension_override: str | None
expected: str | None
cases = [TestCase(None, None), TestCase("", ""), ".png", ".png", ".tar.gz", ".tar.gz"]
for valid_ext_override in [None, "", ".png", ".tar.gz"]:
assert valid_ext_override == _validate_extension_override(valid_ext_override)
for invalid_ext_override in ["png", "tar.gz"]:
with pytest.raises(ValueError) as exc:
_validate_extension_override(invalid_ext_override)
class TestExtractContentTypeAndExtension:
def test_with_both_content_type_and_extension(self):
content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png")
assert content_type == "image/png"
assert extension == ".png"
def test_url_with_file_extension(self):
for content_type in [None, ""]:
content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type)
assert content_type == "image/png"
assert extension == ".png"
def test_response_with_content_type(self):
content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png")
assert content_type == "image/png"
assert extension == ".png"
def test_no_content_type_and_no_extension(self):
for content_type in [None, ""]:
content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type)
assert content_type == "application/octet-stream"
assert extension == ".bin"
class TestGetExtension:
def test_with_extension_override(self):
mime_type = "image/png"
for override in [".jpg", ""]:
extension = _get_extension(mime_type, override)
assert extension == override
def test_without_extension_override(self):
mime_type = "image/png"
extension = _get_extension(mime_type)
assert extension == ".png"

View File

@@ -1,5 +1,8 @@
import base64
import uuid
from collections.abc import Sequence
from typing import Optional
from unittest import mock
import pytest
@@ -30,6 +33,7 @@ from core.workflow.nodes.llm.entities import (
VisionConfig,
VisionConfigOptions,
)
from core.workflow.nodes.llm.file_saver import LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom
from models.provider import ProviderType
@@ -49,8 +53,8 @@ class MockTokenBufferMemory:
@pytest.fixture
def llm_node():
data = LLMNodeData(
def llm_node_data() -> LLMNodeData:
return LLMNodeData(
title="Test LLM",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
prompt_template=[],
@@ -64,42 +68,65 @@ def llm_node():
),
),
)
@pytest.fixture
def graph_init_params() -> GraphInitParams:
return GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
@pytest.fixture
def graph() -> Graph:
return Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
)
@pytest.fixture
def graph_runtime_state() -> GraphRuntimeState:
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
return GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
)
@pytest.fixture
def llm_node(
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
) -> LLMNode:
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
node = LLMNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
"data": llm_node_data.model_dump(),
},
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
return node
@@ -465,3 +492,167 @@ def test_handle_list_messages_basic(llm_node):
assert len(result) == 1
assert isinstance(result[0], UserPromptMessage)
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
@pytest.fixture
def llm_node_for_multimodal(
llm_node_data, graph_init_params, graph, graph_runtime_state
) -> tuple[LLMNode, LLMFileSaver]:
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
node = LLMNode(
id="1",
config={
"id": "1",
"data": llm_node_data.model_dump(),
},
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
return node, mock_file_saver
class TestLLMNodeSaveMultiModalImageOutput:
def test_llm_node_save_inline_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]):
llm_node, mock_file_saver = llm_node_for_multimodal
content = ImagePromptMessageContent(
format="png",
base64_data=base64.b64encode(b"test-data").decode(),
mime_type="image/png",
)
mock_file = File(
id=str(uuid.uuid4()),
tenant_id="1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=str(uuid.uuid4()),
filename="test-file.png",
extension=".png",
mime_type="image/png",
size=9,
)
mock_file_saver.save_binary_string.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
mock_file_saver.save_binary_string.assert_called_once_with(
data=b"test-data", mime_type="image/png", file_type=FileType.IMAGE
)
def test_llm_node_save_url_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]):
llm_node, mock_file_saver = llm_node_for_multimodal
content = ImagePromptMessageContent(
format="png",
url="https://example.com/image.png",
mime_type="image/jpg",
)
mock_file = File(
id=str(uuid.uuid4()),
tenant_id="1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=str(uuid.uuid4()),
filename="test-file.png",
extension=".png",
mime_type="image/png",
size=9,
)
mock_file_saver.save_remote_url.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)
def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
mock_file = mock.MagicMock(spec=File)
mock_file.generate_url.return_value = "https://example.com/image.png"
markdown = llm_node._image_file_to_markdown(mock_file)
assert markdown == "![](https://example.com/image.png)"
class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_str_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world")
assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_text_prompt_message_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[TextPromptMessageContent(data="hello world")]
)
assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_image_content_with_inline_data(self, llm_node_for_multimodal, monkeypatch):
llm_node, mock_file_saver = llm_node_for_multimodal
image_raw_data = b"PNG_DATA"
image_b64_data = base64.b64encode(image_raw_data).decode()
mock_saved_file = File(
id=str(uuid.uuid4()),
tenant_id="1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
filename="test.png",
extension=".png",
size=len(image_raw_data),
related_id=str(uuid.uuid4()),
url="https://example.com/test.png",
storage_key="test_storage_key",
)
mock_file_saver.save_binary_string.return_value = mock_saved_file
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[
ImagePromptMessageContent(
format="png",
base64_data=image_b64_data,
mime_type="image/png",
)
]
)
yielded_strs = list(gen)
assert len(yielded_strs) == 1
# This assertion requires careful handling.
# `FILES_URL` settings can vary across environments, which might lead to fragile tests.
#
# Rather than asserting the complete URL returned by _save_multimodal_output_and_convert_result_to_markdown,
# we verify that the result includes the markdown image syntax and the expected file URL path.
expected_file_url_path = f"/files/tools/{mock_saved_file.related_id}.png"
assert yielded_strs[0].startswith("![](")
assert expected_file_url_path in yielded_strs[0]
assert yielded_strs[0].endswith(")")
mock_file_saver.save_binary_string.assert_called_once_with(
data=image_raw_data,
mime_type="image/png",
file_type=FileType.IMAGE,
)
assert mock_saved_file in llm_node._file_outputs
def test_unknown_content_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"]))
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_unknown_item_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])])
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_none_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None)
assert list(gen) == []
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()