feat(llm_node): support order in text and files (#11837)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2024-12-20 14:12:50 +08:00
committed by GitHub
parent 3599751f93
commit 996a9135f6
18 changed files with 217 additions and 175 deletions

View File

@@ -136,6 +136,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="",
)
]

View File

@@ -1,34 +1,9 @@
import json
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType, FileUploadConfig
from core.file import File, FileTransferMethod, FileType, FileUploadConfig
from models.workflow import Workflow
def test_file_loads_and_dumps():
file = File(
id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
)
file_dict = file.model_dump()
assert file_dict["dify_model_identity"] == FILE_MODEL_IDENTITY
assert file_dict["type"] == file.type.value
assert isinstance(file_dict["type"], str)
assert file_dict["transfer_method"] == file.transfer_method.value
assert isinstance(file_dict["transfer_method"], str)
assert "_extra_config" not in file_dict
file_obj = File.model_validate(file_dict)
assert file_obj.id == file.id
assert file_obj.tenant_id == file.tenant_id
assert file_obj.type == file.type
assert file_obj.transfer_method == file.transfer_method
assert file_obj.remote_url == file.remote_url
def test_file_to_dict():
file = File(
id="file1",
@@ -36,10 +11,11 @@ def test_file_to_dict():
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="storage_key",
)
file_dict = file.to_dict()
assert "_extra_config" not in file_dict
assert "_storage_key" not in file_dict
assert "url" in file_dict

View File

@@ -51,6 +51,7 @@ def test_http_request_node_binary_file(monkeypatch):
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1111",
storage_key="",
),
),
)
@@ -138,6 +139,7 @@ def test_http_request_node_form_with_file(monkeypatch):
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1111",
storage_key="",
),
),
)

View File

@@ -21,7 +21,8 @@ from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment, StringSegment
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
@@ -157,6 +158,7 @@ def test_fetch_files_with_file_segment(llm_node):
filename="test.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
storage_key="",
)
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
@@ -173,6 +175,7 @@ def test_fetch_files_with_array_file_segment(llm_node):
filename="test1.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
storage_key="",
),
File(
id="2",
@@ -181,6 +184,7 @@ def test_fetch_files_with_array_file_segment(llm_node):
filename="test2.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="2",
storage_key="",
),
]
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
@@ -224,14 +228,15 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
storage_key="",
)
]
fake_query = faker.sentence()
prompt_messages, _ = llm_node._fetch_prompt_messages(
user_query=fake_query,
user_files=files,
sys_query=fake_query,
sys_files=files,
context=None,
memory=None,
model_config=model_config,
@@ -283,8 +288,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
test_scenarios = [
LLMNodeTestScenario(
description="No files",
user_query=fake_query,
user_files=[],
sys_query=fake_query,
sys_files=[],
features=[],
vision_enabled=False,
vision_detail=None,
@@ -318,8 +323,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
),
LLMNodeTestScenario(
description="User files",
user_query=fake_query,
user_files=[
sys_query=fake_query,
sys_files=[
File(
tenant_id="test",
type=FileType.IMAGE,
@@ -328,6 +333,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
remote_url=fake_remote_url,
extension=".jpg",
mime_type="image/jpg",
storage_key="",
)
],
vision_enabled=True,
@@ -370,8 +376,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
),
LLMNodeTestScenario(
description="Prompt template with variable selector of File",
user_query=fake_query,
user_files=[],
sys_query=fake_query,
sys_files=[],
vision_enabled=False,
vision_detail=fake_vision_detail,
features=[ModelFeature.VISION],
@@ -403,6 +409,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
remote_url=fake_remote_url,
extension=".jpg",
mime_type="image/jpg",
storage_key="",
)
},
),
@@ -417,8 +424,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
# Call the method under test
prompt_messages, _ = llm_node._fetch_prompt_messages(
user_query=scenario.user_query,
user_files=scenario.user_files,
sys_query=scenario.sys_query,
sys_files=scenario.sys_files,
context=fake_context,
memory=memory,
model_config=model_config,
@@ -435,3 +442,29 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
assert (
prompt_messages == scenario.expected_messages
), f"Message content mismatch in scenario: {scenario.description}"
def test_handle_list_messages_basic(llm_node):
messages = [
LLMNodeChatModelMessage(
text="Hello, {#context#}",
role=PromptMessageRole.USER,
edition_type="basic",
)
]
context = "world"
jinja2_variables = []
variable_pool = llm_node.graph_runtime_state.variable_pool
vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH
result = llm_node._handle_list_messages(
messages=messages,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
vision_detail_config=vision_detail_config,
)
assert len(result) == 1
assert isinstance(result[0], UserPromptMessage)
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]

View File

@@ -12,8 +12,8 @@ class LLMNodeTestScenario(BaseModel):
"""Test scenario for LLM node testing."""
description: str = Field(..., description="Description of the test scenario")
user_query: str = Field(..., description="User query input")
user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
sys_query: str = Field(..., description="User query input")
sys_files: Sequence[File] = Field(default_factory=list, description="List of user files")
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")

View File

@@ -248,6 +248,7 @@ def test_array_file_contains_file_name():
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
filename="ab",
storage_key="",
),
],
)

View File

@@ -57,6 +57,7 @@ def test_filter_files_by_type(list_operator_node):
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related1",
storage_key="",
),
File(
filename="document1.pdf",
@@ -64,6 +65,7 @@ def test_filter_files_by_type(list_operator_node):
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related2",
storage_key="",
),
File(
filename="image2.png",
@@ -71,6 +73,7 @@ def test_filter_files_by_type(list_operator_node):
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related3",
storage_key="",
),
File(
filename="audio1.mp3",
@@ -78,6 +81,7 @@ def test_filter_files_by_type(list_operator_node):
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related4",
storage_key="",
),
]
variable = ArrayFileSegment(value=files)
@@ -130,6 +134,7 @@ def test_get_file_extract_string_func():
mime_type="text/plain",
remote_url="https://example.com/test_file.txt",
related_id="test_related_id",
storage_key="",
)
# Test each case
@@ -150,6 +155,7 @@ def test_get_file_extract_string_func():
mime_type=None,
remote_url=None,
related_id="test_related_id",
storage_key="",
)
assert _get_file_extract_string_func(key="name")(empty_file) == ""

View File

@@ -19,6 +19,7 @@ def file():
related_id="test_related_id",
remote_url="test_url",
filename="test_file.txt",
storage_key="",
)