feat(llm_node): support order in text and files (#11837)

Signed-off-by: -LAN- <laipz8200@outlook.com>

Author:    -LAN-
Date:      2024-12-20 14:12:50 +08:00
Committed: GitHub
Parent:    3599751f93
Commit:    996a9135f6
18 changed files with 217 additions and 175 deletions


@@ -50,6 +50,7 @@ class PromptConfig(BaseModel):
 class LLMNodeChatModelMessage(ChatModelMessage):
+    text: str = ""
     jinja2_text: Optional[str] = None
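
The override above gives text a default empty string, so a message that only carries a Jinja2 template no longer needs an explicit text value. A minimal usage sketch (hypothetical values; the field names are taken from the diff):

    message = LLMNodeChatModelMessage(
        role=PromptMessageRole.USER,
        edition_type="jinja2",
        jinja2_text="{{ last_answer }}",  # hypothetical template
    )
    assert message.text == ""  # falls back to the new default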


@@ -145,8 +145,8 @@ class LLMNode(BaseNode[LLMNodeData]):
         query = query_variable.text
         prompt_messages, stop = self._fetch_prompt_messages(
-            user_query=query,
-            user_files=files,
+            sys_query=query,
+            sys_files=files,
             context=context,
             memory=memory,
             model_config=model_config,
@@ -545,8 +545,8 @@ class LLMNode(BaseNode[LLMNodeData]):
     def _fetch_prompt_messages(
         self,
         *,
-        user_query: str | None = None,
-        user_files: Sequence["File"],
+        sys_query: str | None = None,
+        sys_files: Sequence["File"],
         context: str | None = None,
         memory: TokenBufferMemory | None = None,
         model_config: ModelConfigWithCredentialsEntity,
@@ -562,7 +562,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         if isinstance(prompt_template, list):
             # For chat model
             prompt_messages.extend(
-                _handle_list_messages(
+                self._handle_list_messages(
                     messages=prompt_template,
                     context=context,
                     jinja2_variables=jinja2_variables,
@@ -581,14 +581,14 @@ class LLMNode(BaseNode[LLMNodeData]):
             prompt_messages.extend(memory_messages)
             # Add current query to the prompt messages
-            if user_query:
+            if sys_query:
                 message = LLMNodeChatModelMessage(
-                    text=user_query,
+                    text=sys_query,
                     role=PromptMessageRole.USER,
                     edition_type="basic",
                 )
                 prompt_messages.extend(
-                    _handle_list_messages(
+                    self._handle_list_messages(
                         messages=[message],
                         context="",
                         jinja2_variables=[],
@@ -635,24 +635,27 @@ class LLMNode(BaseNode[LLMNodeData]):
                     raise ValueError("Invalid prompt content type")
             # Add current query to the prompt message
-            if user_query:
+            if sys_query:
                 if prompt_content_type == str:
-                    prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
+                    prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
                     prompt_messages[0].content = prompt_content
                 elif prompt_content_type == list:
                     for content_item in prompt_content:
                         if content_item.type == PromptMessageContentType.TEXT:
-                            content_item.data = user_query + "\n" + content_item.data
+                            content_item.data = sys_query + "\n" + content_item.data
                 else:
                     raise ValueError("Invalid prompt content type")
         else:
             raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
-        if vision_enabled and user_files:
+        # The sys_files will be deprecated later
+        if vision_enabled and sys_files:
             file_prompts = []
-            for file in user_files:
+            for file in sys_files:
                 file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
                 file_prompts.append(file_prompt)
+            # If last prompt is a user prompt, add files into its contents,
+            # otherwise append a new user prompt
             if (
                 len(prompt_messages) > 0
                 and isinstance(prompt_messages[-1], UserPromptMessage)
@@ -662,7 +665,7 @@ class LLMNode(BaseNode[LLMNodeData]):
             else:
                 prompt_messages.append(UserPromptMessage(content=file_prompts))
-        # Filter prompt messages
+        # Remove empty messages and filter unsupported content
         filtered_prompt_messages = []
         for prompt_message in prompt_messages:
             if isinstance(prompt_message.content, list):
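
The new comment in the hunk above describes a merge rule: files join the last user prompt when one exists, otherwise they become a fresh user prompt. A stand-alone sketch with stand-in types (hypothetical names; the check that the last message's content is a list is assumed from the truncated condition):

    class UserMsg:
        def __init__(self, content):
            self.content = content

    def attach_files(prompt_messages, file_prompts):
        # If the last prompt is a user prompt with list content, merge the
        # files into it; otherwise append a new user prompt.
        last = prompt_messages[-1] if prompt_messages else None
        if isinstance(last, UserMsg) and isinstance(last.content, list):
            last.content.extend(file_prompts)
        else:
            prompt_messages.append(UserMsg(content=file_prompts))
        return prompt_messages

    msgs = attach_files([UserMsg(["describe this:"])], ["<image>"])
    print(msgs[0].content)  # ['describe this:', '<image>']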
@@ -846,6 +849,58 @@ class LLMNode(BaseNode[LLMNodeData]):
             },
         }

+    def _handle_list_messages(
+        self,
+        *,
+        messages: Sequence[LLMNodeChatModelMessage],
+        context: Optional[str],
+        jinja2_variables: Sequence[VariableSelector],
+        variable_pool: VariablePool,
+        vision_detail_config: ImagePromptMessageContent.DETAIL,
+    ) -> Sequence[PromptMessage]:
+        prompt_messages: list[PromptMessage] = []
+        for message in messages:
+            contents: list[PromptMessageContent] = []
+            if message.edition_type == "jinja2":
+                result_text = _render_jinja2_message(
+                    template=message.jinja2_text or "",
+                    jinjia2_variables=jinja2_variables,
+                    variable_pool=variable_pool,
+                )
+                contents.append(TextPromptMessageContent(data=result_text))
+            else:
+                # Get segment group from basic message
+                if context:
+                    template = message.text.replace("{#context#}", context)
+                else:
+                    template = message.text
+                segment_group = variable_pool.convert_template(template)
+
+                # Process segments for images
+                for segment in segment_group.value:
+                    if isinstance(segment, ArrayFileSegment):
+                        for file in segment.value:
+                            if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
+                                file_content = file_manager.to_prompt_message_content(
+                                    file, image_detail_config=vision_detail_config
+                                )
+                                contents.append(file_content)
+                    elif isinstance(segment, FileSegment):
+                        file = segment.value
+                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
+                            file_content = file_manager.to_prompt_message_content(
+                                file, image_detail_config=vision_detail_config
+                            )
+                            contents.append(file_content)
+                    else:
+                        plain_text = segment.markdown.strip()
+                        if plain_text:
+                            contents.append(TextPromptMessageContent(data=plain_text))
+
+            prompt_message = _combine_message_content_with_role(contents=contents, role=message.role)
+            prompt_messages.append(prompt_message)
+        return prompt_messages
+

 def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
     match role:
@@ -880,68 +935,6 @@ def _render_jinja2_message(
     return result_text

-def _handle_list_messages(
-    *,
-    messages: Sequence[LLMNodeChatModelMessage],
-    context: Optional[str],
-    jinja2_variables: Sequence[VariableSelector],
-    variable_pool: VariablePool,
-    vision_detail_config: ImagePromptMessageContent.DETAIL,
-) -> Sequence[PromptMessage]:
-    prompt_messages = []
-    for message in messages:
-        if message.edition_type == "jinja2":
-            result_text = _render_jinja2_message(
-                template=message.jinja2_text or "",
-                jinjia2_variables=jinja2_variables,
-                variable_pool=variable_pool,
-            )
-            prompt_message = _combine_message_content_with_role(
-                contents=[TextPromptMessageContent(data=result_text)], role=message.role
-            )
-            prompt_messages.append(prompt_message)
-        else:
-            # Get segment group from basic message
-            if context:
-                template = message.text.replace("{#context#}", context)
-            else:
-                template = message.text
-            segment_group = variable_pool.convert_template(template)
-
-            # Process segments for images
-            file_contents = []
-            for segment in segment_group.value:
-                if isinstance(segment, ArrayFileSegment):
-                    for file in segment.value:
-                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
-                            file_content = file_manager.to_prompt_message_content(
-                                file, image_detail_config=vision_detail_config
-                            )
-                            file_contents.append(file_content)
-                if isinstance(segment, FileSegment):
-                    file = segment.value
-                    if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
-                        file_content = file_manager.to_prompt_message_content(
-                            file, image_detail_config=vision_detail_config
-                        )
-                        file_contents.append(file_content)
-
-            # Create message with text from all segments
-            plain_text = segment_group.text
-            if plain_text:
-                prompt_message = _combine_message_content_with_role(
-                    contents=[TextPromptMessageContent(data=plain_text)], role=message.role
-                )
-                prompt_messages.append(prompt_message)
-            if file_contents:
-                # Create message with image contents
-                prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
-                prompt_messages.append(prompt_message)
-    return prompt_messages

 def _calculate_rest_token(
     *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
 ) -> int:
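
Taken together, the added method and the deleted helper are the behavioral core of this commit: the old function joined all text segments into one message and collected files into a separate trailing message, while the new method appends each segment's content as it is encountered, so text and files keep their template order. An approximate stand-alone sketch with stand-in segment types (hypothetical names, not the real classes):

    from dataclasses import dataclass

    @dataclass
    class TextSeg:  # stand-in for a plain text segment
        markdown: str

    @dataclass
    class FileSeg:  # stand-in for FileSegment
        name: str

    def old_order(segments):
        # Old behavior: one text blob first, files appended afterwards.
        texts = " ".join(s.markdown for s in segments if isinstance(s, TextSeg))
        files = [s.name for s in segments if isinstance(s, FileSeg)]
        return [texts] + files

    def new_order(segments):
        # New behavior: contents kept in template order.
        return [s.name if isinstance(s, FileSeg) else s.markdown for s in segments]

    segs = [TextSeg("Before the image:"), FileSeg("photo.png"), TextSeg("and after it.")]
    print(old_order(segs))  # ['Before the image: and after it.', 'photo.png']
    print(new_order(segs))  # ['Before the image:', 'photo.png', 'and after it.']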


@@ -86,10 +86,10 @@ class QuestionClassifierNode(LLMNode):
             )
             prompt_messages, stop = self._fetch_prompt_messages(
                 prompt_template=prompt_template,
-                user_query=query,
+                sys_query=query,
                 memory=memory,
                 model_config=model_config,
-                user_files=files,
+                sys_files=files,
                 vision_enabled=node_data.vision.enabled,
                 vision_detail=node_data.vision.configs.detail,
                 variable_pool=variable_pool,