feat(api): maintain assistant content parts and file handling in advanced chat (#24663)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable, Generator, Mapping
|
from collections.abc import Callable, Generator, Mapping
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
@@ -373,7 +374,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
) -> Generator[StreamResponse, None, None]:
|
) -> Generator[StreamResponse, None, None]:
|
||||||
"""Handle node succeeded events."""
|
"""Handle node succeeded events."""
|
||||||
# Record files if it's an answer node or end node
|
# Record files if it's an answer node, end node, or LLM node
|
||||||
if event.node_type in [NodeType.ANSWER, NodeType.END]:
|
if event.node_type in [NodeType.ANSWER, NodeType.END, NodeType.LLM]:
|
||||||
self._recorded_files.extend(
|
self._recorded_files.extend(
|
||||||
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
|
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
|
||||||
)
|
)
|
||||||
@@ -896,7 +897,14 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
|
|
||||||
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
||||||
message = self._get_message(session=session)
|
message = self._get_message(session=session)
|
||||||
message.answer = self._task_state.answer
|
|
||||||
|
# If there are assistant files, remove markdown image links from answer
|
||||||
|
answer_text = self._task_state.answer
|
||||||
|
if self._recorded_files:
|
||||||
|
# Remove markdown image links since we're storing files separately
|
||||||
|
answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip()
|
||||||
|
|
||||||
|
message.answer = answer_text
|
||||||
message.updated_at = naive_utc_now()
|
message.updated_at = naive_utc_now()
|
||||||
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
|
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
|
||||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||||
|
@@ -31,6 +31,65 @@ class TokenBufferMemory:
|
|||||||
self.conversation = conversation
|
self.conversation = conversation
|
||||||
self.model_instance = model_instance
|
self.model_instance = model_instance
|
||||||
|
|
||||||
|
def _build_prompt_message_with_files(
    self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool
) -> PromptMessage:
    """
    Assemble one history prompt message, attaching files when they can be built.

    :param message_files: MessageFile rows to attach to the prompt message
    :param text_content: text content of the message
    :param message: Message object; its workflow_run_id is used to look up the
        workflow when the conversation is in advanced-chat or workflow mode
    :param app_record: app record providing tenant_id; when falsy, no files are attached
    :param is_user_message: True to return a UserPromptMessage, False to return
        an AssistantPromptMessage
    :return: prompt message whose content is either the plain text, or the file
        parts followed by a text part
    :raises ValueError: if the workflow run or the workflow cannot be found
    :raises AssertionError: if the conversation mode is not one of the handled modes
    """
    mode = self.conversation.mode
    if mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
        upload_config = FileUploadConfigManager.convert(self.conversation.model_config)
    elif mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
        # Workflow-based apps keep their upload settings on the workflow itself,
        # which is reached through the run recorded on the message.
        run = db.session.scalar(select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id))
        if not run:
            raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
        workflow_record = db.session.scalar(select(Workflow).where(Workflow.id == run.workflow_id))
        if not workflow_record:
            raise ValueError(f"Workflow not found: {run.workflow_id}")
        upload_config = FileUploadConfigManager.convert(workflow_record.features_dict, is_vision=False)
    else:
        raise AssertionError(f"Invalid app mode: {mode}")

    # Both the text-only and the multi-part result use the same message class.
    message_cls = UserPromptMessage if is_user_message else AssistantPromptMessage

    # Without an upload config or an app record there is nothing to attach.
    if not (upload_config and app_record):
        return message_cls(content=text_content)

    # Build files directly without filtering by belongs_to
    attachments = [
        file_factory.build_from_message_file(
            message_file=record, tenant_id=app_record.tenant_id, config=upload_config
        )
        for record in message_files
    ]

    # A configured image detail overrides the HIGH default.
    image_config = upload_config.image_config
    if image_config and image_config.detail:
        image_detail = image_config.detail
    else:
        image_detail = ImagePromptMessageContent.DETAIL.HIGH

    if not attachments:
        return message_cls(content=text_content)

    # File parts come first; the text part is always last.
    parts: list[PromptMessageContentUnionTypes] = [
        file_manager.to_prompt_message_content(
            attachment,
            image_detail_config=image_detail,
        )
        for attachment in attachments
    ]
    parts.append(TextPromptMessageContent(data=text_content))
    return message_cls(content=parts)
|
||||||
|
|
||||||
def get_history_prompt_messages(
|
def get_history_prompt_messages(
|
||||||
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
|
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
|
||||||
) -> Sequence[PromptMessage]:
|
) -> Sequence[PromptMessage]:
|
||||||
@@ -67,51 +126,45 @@ class TokenBufferMemory:
|
|||||||
|
|
||||||
prompt_messages: list[PromptMessage] = []
|
prompt_messages: list[PromptMessage] = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
|
# Process user message with files
|
||||||
if files:
|
user_files = (
|
||||||
file_extra_config = None
|
db.session.query(MessageFile)
|
||||||
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
|
.where(
|
||||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
MessageFile.message_id == message.id,
|
||||||
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
(MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)),
|
||||||
workflow_run = db.session.scalar(
|
|
||||||
select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id)
|
|
||||||
)
|
)
|
||||||
if not workflow_run:
|
.all()
|
||||||
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
|
|
||||||
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
|
||||||
if not workflow:
|
|
||||||
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
|
|
||||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
|
||||||
else:
|
|
||||||
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
|
|
||||||
|
|
||||||
detail = ImagePromptMessageContent.DETAIL.LOW
|
|
||||||
if file_extra_config and app_record:
|
|
||||||
file_objs = file_factory.build_from_message_files(
|
|
||||||
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
|
|
||||||
)
|
)
|
||||||
if file_extra_config.image_config and file_extra_config.image_config.detail:
|
|
||||||
detail = file_extra_config.image_config.detail
|
|
||||||
else:
|
|
||||||
file_objs = []
|
|
||||||
|
|
||||||
if not file_objs:
|
if user_files:
|
||||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
user_prompt_message = self._build_prompt_message_with_files(
|
||||||
else:
|
message_files=user_files,
|
||||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
text_content=message.query,
|
||||||
for file in file_objs:
|
message=message,
|
||||||
prompt_message = file_manager.to_prompt_message_content(
|
app_record=app_record,
|
||||||
file,
|
is_user_message=True,
|
||||||
image_detail_config=detail,
|
|
||||||
)
|
)
|
||||||
prompt_message_contents.append(prompt_message)
|
prompt_messages.append(user_prompt_message)
|
||||||
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||||
|
|
||||||
|
# Process assistant message with files
|
||||||
|
assistant_files = (
|
||||||
|
db.session.query(MessageFile)
|
||||||
|
.where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
if assistant_files:
|
||||||
|
assistant_prompt_message = self._build_prompt_message_with_files(
|
||||||
|
message_files=assistant_files,
|
||||||
|
text_content=message.answer,
|
||||||
|
message=message,
|
||||||
|
app_record=app_record,
|
||||||
|
is_user_message=False,
|
||||||
|
)
|
||||||
|
prompt_messages.append(assistant_prompt_message)
|
||||||
|
else:
|
||||||
prompt_messages.append(AssistantPromptMessage(content=message.answer))
|
prompt_messages.append(AssistantPromptMessage(content=message.answer))
|
||||||
|
|
||||||
if not prompt_messages:
|
if not prompt_messages:
|
||||||
|
@@ -41,8 +41,14 @@ def build_from_message_file(
|
|||||||
"url": message_file.url,
|
"url": message_file.url,
|
||||||
"id": message_file.id,
|
"id": message_file.id,
|
||||||
"type": message_file.type,
|
"type": message_file.type,
|
||||||
"upload_file_id": message_file.upload_file_id,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Set the correct ID field based on transfer method
|
||||||
|
if message_file.transfer_method == FileTransferMethod.TOOL_FILE.value:
|
||||||
|
mapping["tool_file_id"] = message_file.upload_file_id
|
||||||
|
else:
|
||||||
|
mapping["upload_file_id"] = message_file.upload_file_id
|
||||||
|
|
||||||
return build_from_mapping(
|
return build_from_mapping(
|
||||||
mapping=mapping,
|
mapping=mapping,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
@@ -318,6 +324,11 @@ def _is_file_valid_with_config(
|
|||||||
file_transfer_method: FileTransferMethod,
|
file_transfer_method: FileTransferMethod,
|
||||||
config: FileUploadConfig,
|
config: FileUploadConfig,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
# FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model)
|
||||||
|
# These are internally generated and should bypass user upload restrictions
|
||||||
|
if file_transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
|
return True
|
||||||
|
|
||||||
if (
|
if (
|
||||||
config.allowed_file_types
|
config.allowed_file_types
|
||||||
and input_file_type not in config.allowed_file_types
|
and input_file_type not in config.allowed_file_types
|
||||||
|
Reference in New Issue
Block a user