feat: Allow using file variables directly in the LLM node and support more file types. (#10679)

Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
-LAN-
2024-11-22 16:30:22 +08:00
committed by GitHub
parent 535c72cad7
commit c5f7d650b5
36 changed files with 1033 additions and 265 deletions

View File

@@ -39,7 +39,14 @@ class VisionConfig(BaseModel):
class PromptConfig(BaseModel):
jinja2_variables: Optional[list[VariableSelector]] = None
jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list)
@field_validator("jinja2_variables", mode="before")
@classmethod
def convert_none_jinja2_variables(cls, v: Any):
if v is None:
return []
return v
class LLMNodeChatModelMessage(ChatModelMessage):
@@ -53,7 +60,14 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
class LLMNodeData(BaseNodeData):
model: ModelConfig
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
prompt_config: Optional[PromptConfig] = None
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
memory: Optional[MemoryConfig] = None
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
@field_validator("prompt_config", mode="before")
@classmethod
def convert_none_prompt_config(cls, v: Any):
if v is None:
return PromptConfig()
return v

View File

@@ -24,3 +24,11 @@ class LLMModeRequiredError(LLMNodeError):
class NoPromptFoundError(LLMNodeError):
"""Raised when no prompt is found in the LLM configuration."""
class NotSupportedPromptTypeError(LLMNodeError):
"""Raised when the prompt type is not supported."""
class MemoryRolePrefixRequiredError(LLMNodeError):
"""Raised when memory role prefix is required for completion model."""

View File

@@ -1,4 +1,5 @@
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
@@ -6,21 +7,26 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
AudioPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
TextPromptMessageContent,
VideoPromptMessageContent,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageRole,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables import (
@@ -32,8 +38,9 @@ from core.variables import (
ObjectSegment,
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base import BaseNode
@@ -62,14 +69,18 @@ from .exc import (
InvalidVariableTypeError,
LLMModeRequiredError,
LLMNodeError,
MemoryRolePrefixRequiredError,
ModelNotExistError,
NoPromptFoundError,
NotSupportedPromptTypeError,
VariableNotFoundError,
)
if TYPE_CHECKING:
from core.file.models import File
logger = logging.getLogger(__name__)
class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData
@@ -123,17 +134,13 @@ class LLMNode(BaseNode[LLMNodeData]):
# fetch prompt messages
if self.node_data.memory:
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
if not query:
raise VariableNotFoundError("Query not found")
query = query.text
query = self.node_data.memory.query_prompt_template
else:
query = None
prompt_messages, stop = self._fetch_prompt_messages(
system_query=query,
inputs=inputs,
files=files,
user_query=query,
user_files=files,
context=context,
memory=memory,
model_config=model_config,
@@ -141,6 +148,8 @@ class LLMNode(BaseNode[LLMNodeData]):
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
)
process_data = {
@@ -181,6 +190,17 @@ class LLMNode(BaseNode[LLMNodeData]):
)
)
return
except Exception as e:
logger.exception(f"Node {self.node_id} failed to run")
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=node_inputs,
process_data=process_data,
)
)
return
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
@@ -203,8 +223,8 @@ class LLMNode(BaseNode[LLMNodeData]):
self,
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None,
) -> Generator[NodeEvent, None, None]:
db.session.close()
@@ -519,9 +539,8 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_prompt_messages(
self,
*,
system_query: str | None = None,
inputs: dict[str, str] | None = None,
files: Sequence["File"],
user_query: str | None = None,
user_files: Sequence["File"],
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
@@ -529,58 +548,146 @@ class LLMNode(BaseNode[LLMNodeData]):
memory_config: MemoryConfig | None = None,
vision_enabled: bool = False,
vision_detail: ImagePromptMessageContent.DETAIL,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
inputs = inputs or {}
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_messages = []
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs=inputs,
query=system_query or "",
files=files,
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config,
)
stop = model_config.stop
if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(
_handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
vision_detail_config=vision_detail,
)
)
# Get memory messages for chat mode
memory_messages = _handle_memory_chat_mode(
memory=memory,
memory_config=memory_config,
model_config=model_config,
)
# Extend prompt_messages with memory messages
prompt_messages.extend(memory_messages)
# Add current query to the prompt messages
if user_query:
message = LLMNodeChatModelMessage(
text=user_query,
role=PromptMessageRole.USER,
edition_type="basic",
)
prompt_messages.extend(
_handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
variable_pool=variable_pool,
vision_detail_config=vision_detail,
)
)
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
# For completion model
prompt_messages.extend(
_handle_completion_template(
template=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
)
# Get memory text for completion model
memory_text = _handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_config=model_config,
)
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
# Add current query to the prompt message
if user_query:
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
prompt_messages[0].content = prompt_content
else:
errmsg = f"Prompt type {type(prompt_template)} is not supported"
logger.warning(errmsg)
raise NotSupportedPromptTypeError(errmsg)
if vision_enabled and user_files:
file_prompts = []
for file in user_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Filter prompt messages
filtered_prompt_messages = []
for prompt_message in prompt_messages:
if prompt_message.is_empty():
continue
if not isinstance(prompt_message.content, str):
if isinstance(prompt_message.content, list):
prompt_message_content = []
for content_item in prompt_message.content or []:
# Skip image if vision is disabled
if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE:
for content_item in prompt_message.content:
# Skip content if features are not defined
if not model_config.model_schema.features:
if content_item.type != PromptMessageContentType.TEXT:
continue
prompt_message_content.append(content_item)
continue
if isinstance(content_item, ImagePromptMessageContent):
# Override vision config if LLM node has vision config,
# cuz vision detail is related to the configuration from FileUpload feature.
content_item.detail = vision_detail
prompt_message_content.append(content_item)
elif isinstance(
content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
# Skip content if corresponding feature is not supported
if (
(
content_item.type == PromptMessageContentType.IMAGE
and ModelFeature.VISION not in model_config.model_schema.features
)
or (
content_item.type == PromptMessageContentType.DOCUMENT
and ModelFeature.DOCUMENT not in model_config.model_schema.features
)
or (
content_item.type == PromptMessageContentType.VIDEO
and ModelFeature.VIDEO not in model_config.model_schema.features
)
or (
content_item.type == PromptMessageContentType.AUDIO
and ModelFeature.AUDIO not in model_config.model_schema.features
)
):
prompt_message_content.append(content_item)
if len(prompt_message_content) > 1:
prompt_message.content = prompt_message_content
elif (
len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
):
continue
prompt_message_content.append(content_item)
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
prompt_message.content = prompt_message_content[0].data
else:
prompt_message.content = prompt_message_content
if prompt_message.is_empty():
continue
filtered_prompt_messages.append(prompt_message)
if not filtered_prompt_messages:
if len(filtered_prompt_messages) == 0:
raise NoPromptFoundError(
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
stop = model_config.stop
return filtered_prompt_messages, stop
@classmethod
@@ -715,3 +822,198 @@ class LLMNode(BaseNode[LLMNodeData]):
}
},
}
def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
match role:
case PromptMessageRole.USER:
return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
case PromptMessageRole.ASSISTANT:
return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
raise NotImplementedError(f"Role {role} is not supported")
def _render_jinja2_message(
*,
template: str,
jinjia2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
):
if not template:
return ""
jinjia2_inputs = {}
for jinja2_variable in jinjia2_variables:
variable = variable_pool.get(jinja2_variable.value_selector)
jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
code_execute_resp = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2,
code=template,
inputs=jinjia2_inputs,
)
result_text = code_execute_resp["result"]
return result_text
def _handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
prompt_messages = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)
# Process segments for images
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
if isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = UserPromptMessage(content=file_contents)
prompt_messages.append(prompt_message)
return prompt_messages
def _calculate_rest_token(
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(str(parameter_rule.use_template))
or 0
)
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
def _handle_memory_chat_mode(
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
memory_messages = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
return memory_messages
def _handle_memory_completion_mode(
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> str:
memory_text = ""
# Get history text from memory for completion model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_text = memory.get_history_prompt_text(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
) -> Sequence[PromptMessage]:
"""Handle completion template processing outside of LLMNode class.
Args:
template: The completion model prompt template
context: Optional context string
jinja2_variables: Variables for jinja2 template rendering
variable_pool: Variable pool for template conversion
Returns:
Sequence of prompt messages
"""
prompt_messages = []
if template.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=template.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
else:
if context:
template_text = template.text.replace("{#context#}", context)
else:
template_text = template.text
result_text = variable_pool.convert_template(template_text).text
prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
prompt_messages.append(prompt_message)
return prompt_messages

View File

@@ -86,12 +86,14 @@ class QuestionClassifierNode(LLMNode):
)
prompt_messages, stop = self._fetch_prompt_messages(
prompt_template=prompt_template,
system_query=query,
user_query=query,
memory=memory,
model_config=model_config,
files=files,
user_files=files,
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
)
# handle invoke result