feat(llm_node): support order in text and files (#11837)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -136,6 +136,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image1.jpg",
|
||||
storage_key="",
|
||||
)
|
||||
]
|
||||
|
||||
|
@@ -1,34 +1,9 @@
|
||||
import json
|
||||
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType, FileUploadConfig
|
||||
from core.file import File, FileTransferMethod, FileType, FileUploadConfig
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
def test_file_loads_and_dumps():
|
||||
file = File(
|
||||
id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image1.jpg",
|
||||
)
|
||||
|
||||
file_dict = file.model_dump()
|
||||
assert file_dict["dify_model_identity"] == FILE_MODEL_IDENTITY
|
||||
assert file_dict["type"] == file.type.value
|
||||
assert isinstance(file_dict["type"], str)
|
||||
assert file_dict["transfer_method"] == file.transfer_method.value
|
||||
assert isinstance(file_dict["transfer_method"], str)
|
||||
assert "_extra_config" not in file_dict
|
||||
|
||||
file_obj = File.model_validate(file_dict)
|
||||
assert file_obj.id == file.id
|
||||
assert file_obj.tenant_id == file.tenant_id
|
||||
assert file_obj.type == file.type
|
||||
assert file_obj.transfer_method == file.transfer_method
|
||||
assert file_obj.remote_url == file.remote_url
|
||||
|
||||
|
||||
def test_file_to_dict():
|
||||
file = File(
|
||||
id="file1",
|
||||
@@ -36,10 +11,11 @@ def test_file_to_dict():
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image1.jpg",
|
||||
storage_key="storage_key",
|
||||
)
|
||||
|
||||
file_dict = file.to_dict()
|
||||
assert "_extra_config" not in file_dict
|
||||
assert "_storage_key" not in file_dict
|
||||
assert "url" in file_dict
|
||||
|
||||
|
||||
|
@@ -51,6 +51,7 @@ def test_http_request_node_binary_file(monkeypatch):
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1111",
|
||||
storage_key="",
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -138,6 +139,7 @@ def test_http_request_node_form_with_file(monkeypatch):
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1111",
|
||||
storage_key="",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
@@ -21,7 +21,8 @@ from core.model_runtime.entities.message_entities import (
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
||||
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment, StringSegment
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||
@@ -157,6 +158,7 @@ def test_fetch_files_with_file_segment(llm_node):
|
||||
filename="test.jpg",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1",
|
||||
storage_key="",
|
||||
)
|
||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
|
||||
|
||||
@@ -173,6 +175,7 @@ def test_fetch_files_with_array_file_segment(llm_node):
|
||||
filename="test1.jpg",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1",
|
||||
storage_key="",
|
||||
),
|
||||
File(
|
||||
id="2",
|
||||
@@ -181,6 +184,7 @@ def test_fetch_files_with_array_file_segment(llm_node):
|
||||
filename="test2.jpg",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="2",
|
||||
storage_key="",
|
||||
),
|
||||
]
|
||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
|
||||
@@ -224,14 +228,15 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
|
||||
filename="test1.jpg",
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url=fake_remote_url,
|
||||
storage_key="",
|
||||
)
|
||||
]
|
||||
|
||||
fake_query = faker.sentence()
|
||||
|
||||
prompt_messages, _ = llm_node._fetch_prompt_messages(
|
||||
user_query=fake_query,
|
||||
user_files=files,
|
||||
sys_query=fake_query,
|
||||
sys_files=files,
|
||||
context=None,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
@@ -283,8 +288,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
test_scenarios = [
|
||||
LLMNodeTestScenario(
|
||||
description="No files",
|
||||
user_query=fake_query,
|
||||
user_files=[],
|
||||
sys_query=fake_query,
|
||||
sys_files=[],
|
||||
features=[],
|
||||
vision_enabled=False,
|
||||
vision_detail=None,
|
||||
@@ -318,8 +323,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
),
|
||||
LLMNodeTestScenario(
|
||||
description="User files",
|
||||
user_query=fake_query,
|
||||
user_files=[
|
||||
sys_query=fake_query,
|
||||
sys_files=[
|
||||
File(
|
||||
tenant_id="test",
|
||||
type=FileType.IMAGE,
|
||||
@@ -328,6 +333,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
remote_url=fake_remote_url,
|
||||
extension=".jpg",
|
||||
mime_type="image/jpg",
|
||||
storage_key="",
|
||||
)
|
||||
],
|
||||
vision_enabled=True,
|
||||
@@ -370,8 +376,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
),
|
||||
LLMNodeTestScenario(
|
||||
description="Prompt template with variable selector of File",
|
||||
user_query=fake_query,
|
||||
user_files=[],
|
||||
sys_query=fake_query,
|
||||
sys_files=[],
|
||||
vision_enabled=False,
|
||||
vision_detail=fake_vision_detail,
|
||||
features=[ModelFeature.VISION],
|
||||
@@ -403,6 +409,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
remote_url=fake_remote_url,
|
||||
extension=".jpg",
|
||||
mime_type="image/jpg",
|
||||
storage_key="",
|
||||
)
|
||||
},
|
||||
),
|
||||
@@ -417,8 +424,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
|
||||
# Call the method under test
|
||||
prompt_messages, _ = llm_node._fetch_prompt_messages(
|
||||
user_query=scenario.user_query,
|
||||
user_files=scenario.user_files,
|
||||
sys_query=scenario.sys_query,
|
||||
sys_files=scenario.sys_files,
|
||||
context=fake_context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
@@ -435,3 +442,29 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
assert (
|
||||
prompt_messages == scenario.expected_messages
|
||||
), f"Message content mismatch in scenario: {scenario.description}"
|
||||
|
||||
|
||||
def test_handle_list_messages_basic(llm_node):
|
||||
messages = [
|
||||
LLMNodeChatModelMessage(
|
||||
text="Hello, {#context#}",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
]
|
||||
context = "world"
|
||||
jinja2_variables = []
|
||||
variable_pool = llm_node.graph_runtime_state.variable_pool
|
||||
vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH
|
||||
|
||||
result = llm_node._handle_list_messages(
|
||||
messages=messages,
|
||||
context=context,
|
||||
jinja2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
vision_detail_config=vision_detail_config,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], UserPromptMessage)
|
||||
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
|
||||
|
@@ -12,8 +12,8 @@ class LLMNodeTestScenario(BaseModel):
|
||||
"""Test scenario for LLM node testing."""
|
||||
|
||||
description: str = Field(..., description="Description of the test scenario")
|
||||
user_query: str = Field(..., description="User query input")
|
||||
user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
|
||||
sys_query: str = Field(..., description="User query input")
|
||||
sys_files: Sequence[File] = Field(default_factory=list, description="List of user files")
|
||||
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
|
||||
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
|
||||
features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
|
||||
|
@@ -248,6 +248,7 @@ def test_array_file_contains_file_name():
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1",
|
||||
filename="ab",
|
||||
storage_key="",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@@ -57,6 +57,7 @@ def test_filter_files_by_type(list_operator_node):
|
||||
tenant_id="tenant1",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related1",
|
||||
storage_key="",
|
||||
),
|
||||
File(
|
||||
filename="document1.pdf",
|
||||
@@ -64,6 +65,7 @@ def test_filter_files_by_type(list_operator_node):
|
||||
tenant_id="tenant1",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related2",
|
||||
storage_key="",
|
||||
),
|
||||
File(
|
||||
filename="image2.png",
|
||||
@@ -71,6 +73,7 @@ def test_filter_files_by_type(list_operator_node):
|
||||
tenant_id="tenant1",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related3",
|
||||
storage_key="",
|
||||
),
|
||||
File(
|
||||
filename="audio1.mp3",
|
||||
@@ -78,6 +81,7 @@ def test_filter_files_by_type(list_operator_node):
|
||||
tenant_id="tenant1",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related4",
|
||||
storage_key="",
|
||||
),
|
||||
]
|
||||
variable = ArrayFileSegment(value=files)
|
||||
@@ -130,6 +134,7 @@ def test_get_file_extract_string_func():
|
||||
mime_type="text/plain",
|
||||
remote_url="https://example.com/test_file.txt",
|
||||
related_id="test_related_id",
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
# Test each case
|
||||
@@ -150,6 +155,7 @@ def test_get_file_extract_string_func():
|
||||
mime_type=None,
|
||||
remote_url=None,
|
||||
related_id="test_related_id",
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
assert _get_file_extract_string_func(key="name")(empty_file) == ""
|
||||
|
@@ -19,6 +19,7 @@ def file():
|
||||
related_id="test_related_id",
|
||||
remote_url="test_url",
|
||||
filename="test_file.txt",
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user