@@ -1,4 +1,5 @@
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
@@ -6,21 +7,26 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
    AudioPromptMessageContent,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentType,
    TextPromptMessageContent,
    VideoPromptMessageContent,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageRole,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables import (
@@ -32,8 +38,9 @@ from core.variables import (
    ObjectSegment,
    StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base import BaseNode
@@ -62,14 +69,18 @@ from .exc import (
    InvalidVariableTypeError,
    LLMModeRequiredError,
    LLMNodeError,
    MemoryRolePrefixRequiredError,
    ModelNotExistError,
    NoPromptFoundError,
    NotSupportedPromptTypeError,
    VariableNotFoundError,
)

if TYPE_CHECKING:
    from core.file.models import File

logger = logging.getLogger(__name__)


class LLMNode(BaseNode[LLMNodeData]):
    _node_data_cls = LLMNodeData
@@ -123,17 +134,13 @@ class LLMNode(BaseNode[LLMNodeData]):

        # fetch prompt messages
        if self.node_data.memory:
            query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
            if not query:
                raise VariableNotFoundError("Query not found")
            query = query.text
            query = self.node_data.memory.query_prompt_template
        else:
            query = None

        prompt_messages, stop = self._fetch_prompt_messages(
            system_query=query,
            inputs=inputs,
            files=files,
            user_query=query,
            user_files=files,
            context=context,
            memory=memory,
            model_config=model_config,
@@ -141,6 +148,8 @@ class LLMNode(BaseNode[LLMNodeData]):
            memory_config=self.node_data.memory,
            vision_enabled=self.node_data.vision.enabled,
            vision_detail=self.node_data.vision.configs.detail,
            variable_pool=self.graph_runtime_state.variable_pool,
            jinja2_variables=self.node_data.prompt_config.jinja2_variables,
        )

        process_data = {
@@ -181,6 +190,17 @@ class LLMNode(BaseNode[LLMNodeData]):
                )
            )
            return
        except Exception as e:
            logger.exception(f"Node {self.node_id} failed to run")
            yield RunCompletedEvent(
                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    error=str(e),
                    inputs=node_inputs,
                    process_data=process_data,
                )
            )
            return

        outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
@@ -203,8 +223,8 @@ class LLMNode(BaseNode[LLMNodeData]):
        self,
        node_data_model: ModelConfig,
        model_instance: ModelInstance,
        prompt_messages: list[PromptMessage],
        stop: Optional[list[str]] = None,
        prompt_messages: Sequence[PromptMessage],
        stop: Optional[Sequence[str]] = None,
    ) -> Generator[NodeEvent, None, None]:
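        # Close the DB session before the (potentially long-running) model invocation, so the
        # connection is not held open while the response streams.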
        db.session.close()
@@ -519,9 +539,8 @@ class LLMNode(BaseNode[LLMNodeData]):
    def _fetch_prompt_messages(
        self,
        *,
        system_query: str | None = None,
        inputs: dict[str, str] | None = None,
        files: Sequence["File"],
        user_query: str | None = None,
        user_files: Sequence["File"],
        context: str | None = None,
        memory: TokenBufferMemory | None = None,
        model_config: ModelConfigWithCredentialsEntity,
@@ -529,58 +548,146 @@ class LLMNode(BaseNode[LLMNodeData]):
        memory_config: MemoryConfig | None = None,
        vision_enabled: bool = False,
        vision_detail: ImagePromptMessageContent.DETAIL,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        inputs = inputs or {}
        variable_pool: VariablePool,
        jinja2_variables: Sequence[VariableSelector],
    ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
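        # Assemble the final prompt: template messages, memory history, the current user query and
        # any user files, then drop content parts that the model's declared features cannot accept.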
        prompt_messages = []

        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=prompt_template,
            inputs=inputs,
            query=system_query or "",
            files=files,
            context=context,
            memory_config=memory_config,
            memory=memory,
            model_config=model_config,
        )
        stop = model_config.stop
        if isinstance(prompt_template, list):
            # For chat model
            prompt_messages.extend(
                _handle_list_messages(
                    messages=prompt_template,
                    context=context,
                    jinja2_variables=jinja2_variables,
                    variable_pool=variable_pool,
                    vision_detail_config=vision_detail,
                )
            )

            # Get memory messages for chat mode
            memory_messages = _handle_memory_chat_mode(
                memory=memory,
                memory_config=memory_config,
                model_config=model_config,
            )
            # Extend prompt_messages with memory messages
            prompt_messages.extend(memory_messages)

            # Add current query to the prompt messages
            if user_query:
                message = LLMNodeChatModelMessage(
                    text=user_query,
                    role=PromptMessageRole.USER,
                    edition_type="basic",
                )
                prompt_messages.extend(
                    _handle_list_messages(
                        messages=[message],
                        context="",
                        jinja2_variables=[],
                        variable_pool=variable_pool,
                        vision_detail_config=vision_detail,
                    )
                )

        elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
            # For completion model
            prompt_messages.extend(
                _handle_completion_template(
                    template=prompt_template,
                    context=context,
                    jinja2_variables=jinja2_variables,
                    variable_pool=variable_pool,
                )
            )

            # Get memory text for completion model
            memory_text = _handle_memory_completion_mode(
                memory=memory,
                memory_config=memory_config,
                model_config=model_config,
            )
            # Insert histories into the prompt
            prompt_content = prompt_messages[0].content
            if "#histories#" in prompt_content:
                prompt_content = prompt_content.replace("#histories#", memory_text)
            else:
                prompt_content = memory_text + "\n" + prompt_content
            prompt_messages[0].content = prompt_content

            # Add current query to the prompt message
            if user_query:
                prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
                prompt_messages[0].content = prompt_content
        else:
            errmsg = f"Prompt type {type(prompt_template)} is not supported"
            logger.warning(errmsg)
            raise NotSupportedPromptTypeError(errmsg)

        if vision_enabled and user_files:
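            # Convert the uploaded files into prompt message contents and attach them to the last
            # user message when it already holds a content list, otherwise append a new user message.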
            file_prompts = []
            for file in user_files:
                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
                file_prompts.append(file_prompt)
            if (
                len(prompt_messages) > 0
                and isinstance(prompt_messages[-1], UserPromptMessage)
                and isinstance(prompt_messages[-1].content, list)
            ):
                prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
            else:
                prompt_messages.append(UserPromptMessage(content=file_prompts))

        # Filter prompt messages
        filtered_prompt_messages = []
        for prompt_message in prompt_messages:
            if prompt_message.is_empty():
                continue

            if not isinstance(prompt_message.content, str):
            if isinstance(prompt_message.content, list):
                prompt_message_content = []
                for content_item in prompt_message.content or []:
                    # Skip image if vision is disabled
                    if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE:
                for content_item in prompt_message.content:
                    # Skip content if features are not defined
                    if not model_config.model_schema.features:
                        if content_item.type != PromptMessageContentType.TEXT:
                            continue
                        prompt_message_content.append(content_item)
                        continue

                    if isinstance(content_item, ImagePromptMessageContent):
                        # Override the vision detail with the LLM node's own vision config,
                        # because the default detail comes from the FileUpload feature configuration.
                        content_item.detail = vision_detail
                        prompt_message_content.append(content_item)
                    elif isinstance(
                        content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
                    # Skip content if corresponding feature is not supported
                    if (
                        (
                            content_item.type == PromptMessageContentType.IMAGE
                            and ModelFeature.VISION not in model_config.model_schema.features
                        )
                        or (
                            content_item.type == PromptMessageContentType.DOCUMENT
                            and ModelFeature.DOCUMENT not in model_config.model_schema.features
                        )
                        or (
                            content_item.type == PromptMessageContentType.VIDEO
                            and ModelFeature.VIDEO not in model_config.model_schema.features
                        )
                        or (
                            content_item.type == PromptMessageContentType.AUDIO
                            and ModelFeature.AUDIO not in model_config.model_schema.features
                        )
                    ):
                        prompt_message_content.append(content_item)

                if len(prompt_message_content) > 1:
                    prompt_message.content = prompt_message_content
                elif (
                    len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
                ):
                    continue
                        prompt_message_content.append(content_item)
                if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
                    prompt_message.content = prompt_message_content[0].data

                else:
                    prompt_message.content = prompt_message_content
            if prompt_message.is_empty():
                continue
            filtered_prompt_messages.append(prompt_message)

        if not filtered_prompt_messages:
        if len(filtered_prompt_messages) == 0:
            raise NoPromptFoundError(
                "No prompt found in the LLM configuration. "
                "Please ensure a prompt is properly configured before proceeding."
            )

        stop = model_config.stop
        return filtered_prompt_messages, stop

    @classmethod
@@ -715,3 +822,198 @@ class LLMNode(BaseNode[LLMNodeData]):
            }
        },
    }


def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
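    # Wrap plain text in the prompt message class that matches the given role.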
    match role:
        case PromptMessageRole.USER:
            return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
        case PromptMessageRole.ASSISTANT:
            return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
        case PromptMessageRole.SYSTEM:
            return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
    raise NotImplementedError(f"Role {role} is not supported")


def _render_jinja2_message(
    *,
    template: str,
    jinjia2_variables: Sequence[VariableSelector],
    variable_pool: VariablePool,
):
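    # Resolve each referenced variable from the variable pool, then render the Jinja2 template
    # through the code executor and return the resulting text.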
    if not template:
        return ""

    jinjia2_inputs = {}
    for jinja2_variable in jinjia2_variables:
        variable = variable_pool.get(jinja2_variable.value_selector)
        jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
    code_execute_resp = CodeExecutor.execute_workflow_code_template(
        language=CodeLanguage.JINJA2,
        code=template,
        inputs=jinjia2_inputs,
    )
    result_text = code_execute_resp["result"]
    return result_text


def _handle_list_messages(
    *,
    messages: Sequence[LLMNodeChatModelMessage],
    context: Optional[str],
    jinja2_variables: Sequence[VariableSelector],
    variable_pool: VariablePool,
    vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
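    # Convert the node's configured chat messages into prompt messages: jinja2 messages are
    # rendered first, while basic messages get {#context#} substitution, variable-template
    # conversion, and any file segments collected into an additional user message.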
    prompt_messages = []
    for message in messages:
        if message.edition_type == "jinja2":
            result_text = _render_jinja2_message(
                template=message.jinja2_text or "",
                jinjia2_variables=jinja2_variables,
                variable_pool=variable_pool,
            )
            prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
            prompt_messages.append(prompt_message)
        else:
            # Get segment group from basic message
            if context:
                template = message.text.replace("{#context#}", context)
            else:
                template = message.text
            segment_group = variable_pool.convert_template(template)

            # Process segments for file contents
            file_contents = []
            for segment in segment_group.value:
                if isinstance(segment, ArrayFileSegment):
                    for file in segment.value:
                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
                            file_content = file_manager.to_prompt_message_content(
                                file, image_detail_config=vision_detail_config
                            )
                            file_contents.append(file_content)
                if isinstance(segment, FileSegment):
                    file = segment.value
                    if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
                        file_content = file_manager.to_prompt_message_content(
                            file, image_detail_config=vision_detail_config
                        )
                        file_contents.append(file_content)

            # Create message with text from all segments
            plain_text = segment_group.text
            if plain_text:
                prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
                prompt_messages.append(prompt_message)

            if file_contents:
                # Create message with file contents
                prompt_message = UserPromptMessage(content=file_contents)
                prompt_messages.append(prompt_message)

    return prompt_messages


def _calculate_rest_token(
    *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
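    # Estimate the tokens left for memory and context: the model's context size minus the
    # configured max_tokens and the tokens already used by the prompt. Falls back to 2000
    # when the model does not declare a context size.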
    rest_tokens = 2000

    model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
    if model_context_tokens:
        model_instance = ModelInstance(
            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
        )

        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if parameter_rule.name == "max_tokens" or (
                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
            ):
                max_tokens = (
                    model_config.parameters.get(parameter_rule.name)
                    or model_config.parameters.get(str(parameter_rule.use_template))
                    or 0
                )

        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
        rest_tokens = max(rest_tokens, 0)

    return rest_tokens


def _handle_memory_chat_mode(
    *,
    memory: TokenBufferMemory | None,
    memory_config: MemoryConfig | None,
    model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
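    # Fetch earlier conversation turns as prompt messages, bounded by the remaining token
    # budget and, when the window is enabled, the configured message limit.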
    memory_messages = []
    # Get messages from memory for chat model
    if memory and memory_config:
        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
        memory_messages = memory.get_history_prompt_messages(
            max_token_limit=rest_tokens,
            message_limit=memory_config.window.size if memory_config.window.enabled else None,
        )
    return memory_messages


def _handle_memory_completion_mode(
    *,
    memory: TokenBufferMemory | None,
    memory_config: MemoryConfig | None,
    model_config: ModelConfigWithCredentialsEntity,
) -> str:
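    # Render earlier conversation turns as plain text using the configured role prefixes;
    # completion models require memory_config.role_prefix to be set.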
    memory_text = ""
    # Get history text from memory for completion model
    if memory and memory_config:
        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
        if not memory_config.role_prefix:
            raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
        memory_text = memory.get_history_prompt_text(
            max_token_limit=rest_tokens,
            message_limit=memory_config.window.size if memory_config.window.enabled else None,
            human_prefix=memory_config.role_prefix.user,
            ai_prefix=memory_config.role_prefix.assistant,
        )
    return memory_text


def _handle_completion_template(
    *,
    template: LLMNodeCompletionModelPromptTemplate,
    context: Optional[str],
    jinja2_variables: Sequence[VariableSelector],
    variable_pool: VariablePool,
) -> Sequence[PromptMessage]:
    """Handle completion template processing outside of LLMNode class.

    Args:
        template: The completion model prompt template
        context: Optional context string
        jinja2_variables: Variables for jinja2 template rendering
        variable_pool: Variable pool for template conversion

    Returns:
        Sequence of prompt messages
    """
    prompt_messages = []
    if template.edition_type == "jinja2":
        result_text = _render_jinja2_message(
            template=template.jinja2_text or "",
            jinjia2_variables=jinja2_variables,
            variable_pool=variable_pool,
        )
    else:
        if context:
            template_text = template.text.replace("{#context#}", context)
        else:
            template_text = template.text
        result_text = variable_pool.convert_template(template_text).text
    prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
    prompt_messages.append(prompt_message)
    return prompt_messages