Refactor/message cycle manage and knowledge retrieval (#20460)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2025-05-30 14:36:44 +08:00
committed by GitHub
parent 5a991295e0
commit a6ea15e63c
14 changed files with 222 additions and 181 deletions

View File

@@ -1,4 +1,3 @@
import json
import logging
import time
from collections.abc import Generator, Mapping
@@ -57,10 +56,9 @@ from core.app.entities.task_entities import (
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import SystemVariableKey
@@ -141,7 +139,7 @@ class AdvancedChatAppGenerateTaskPipeline:
)
self._task_state = WorkflowTaskState()
self._message_cycle_manager = MessageCycleManage(
self._message_cycle_manager = MessageCycleManager(
application_generate_entity=application_generate_entity, task_state=self._task_state
)
@@ -162,7 +160,7 @@ class AdvancedChatAppGenerateTaskPipeline:
:return:
"""
# start generate conversation name thread
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)
@@ -605,22 +603,18 @@ class AdvancedChatAppGenerateTaskPipeline:
yield self._message_end_to_stream_response()
break
elif isinstance(event, QueueRetrieverResourcesEvent):
self._message_cycle_manager._handle_retriever_resources(event)
self._message_cycle_manager.handle_retriever_resources(event)
with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message.message_metadata = self._task_state.metadata.model_dump_json()
session.commit()
elif isinstance(event, QueueAnnotationReplyEvent):
self._message_cycle_manager._handle_annotation_reply(event)
self._message_cycle_manager.handle_annotation_reply(event)
with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message.message_metadata = self._task_state.metadata.model_dump_json()
session.commit()
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
@@ -637,12 +631,12 @@ class AdvancedChatAppGenerateTaskPipeline:
tts_publisher.publish(queue_message)
self._task_state.answer += delta_text
yield self._message_cycle_manager._message_to_stream_response(
yield self._message_cycle_manager.message_to_stream_response(
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
yield self._message_cycle_manager._message_replace_to_stream_response(
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=event.text, reason=event.reason
)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
@@ -654,7 +648,7 @@ class AdvancedChatAppGenerateTaskPipeline:
)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_cycle_manager._message_replace_to_stream_response(
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=output_moderation_answer,
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
)
@@ -683,9 +677,7 @@ class AdvancedChatAppGenerateTaskPipeline:
message = self._get_message(session=session)
message.answer = self._task_state.answer
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message.message_metadata = self._task_state.metadata.model_dump_json()
message_files = [
MessageFile(
message_id=message.id,
@@ -713,9 +705,9 @@ class AdvancedChatAppGenerateTaskPipeline:
message.answer_price_unit = usage.completion_price_unit
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.metadata["usage"] = jsonable_encoder(usage)
self._task_state.metadata.usage = usage
else:
self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage())
self._task_state.metadata.usage = LLMUsage.empty_usage()
message_was_created.send(
message,
application_generate_entity=self._application_generate_entity,
@@ -726,18 +718,16 @@ class AdvancedChatAppGenerateTaskPipeline:
Message end to stream response.
:return:
"""
extras = {}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.copy()
extras = self._task_state.metadata.model_dump()
if "annotation_reply" in extras["metadata"]:
del extras["metadata"]["annotation_reply"]
if self._task_state.metadata.annotation_reply:
del extras["annotation_reply"]
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message_id,
files=self._recorded_files,
metadata=extras.get("metadata", {}),
metadata=extras,
)
def _handle_output_moderation_chunk(self, text: str) -> bool:

View File

@@ -50,7 +50,6 @@ from core.app.entities.task_entities import (
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
@@ -130,9 +129,7 @@ class WorkflowAppGenerateTaskPipeline:
)
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._task_state = WorkflowTaskState()
self._workflow_run_id = ""
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@@ -543,7 +540,6 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher:
tts_publisher.publish(queue_message)
self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector
)

View File

@@ -1,4 +1,4 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Optional
@@ -6,6 +6,7 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
@@ -283,7 +284,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: list[dict]
retriever_resources: Sequence[RetrievalSourceMetadata]
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None

View File

@@ -2,20 +2,37 @@ from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class AnnotationReplyAccount(BaseModel):
    """
    Account entity attached to an annotation reply.
    """

    # Identifier of the account associated with the annotation.
    id: str
    # Display name of that account.
    name: str
class AnnotationReply(BaseModel):
    """
    Annotation reply entity stored in the task state metadata.
    """

    # Identifier of the annotation that was used for the reply.
    id: str
    # Account associated with that annotation.
    account: AnnotationReplyAccount
class TaskStateMetadata(BaseModel):
    """
    Typed metadata entity carried by a task state.
    """

    # Annotation reply details; stays None unless one is assigned.
    annotation_reply: AnnotationReply | None = None
    # Retrieval sources for the message; every instance starts with its own
    # empty list because the default comes from a factory.
    retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list)
    # LLM usage for the message; stays None unless it is assigned.
    usage: LLMUsage | None = None
class TaskState(BaseModel):
"""
TaskState entity
"""
metadata: dict = {}
metadata: TaskStateMetadata = Field(default_factory=TaskStateMetadata)
class EasyUITaskState(TaskState):

View File

@@ -1,4 +1,3 @@
import json
import logging
import time
from collections.abc import Generator
@@ -43,7 +42,7 @@ from core.app.entities.task_entities import (
StreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -51,7 +50,6 @@ from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@@ -63,7 +61,7 @@ from models.model import AppMode, Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage):
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
"""
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
@@ -104,6 +102,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
)
)
self._message_cycle_manager = MessageCycleManager(
application_generate_entity=application_generate_entity,
task_state=self._task_state,
)
self._conversation_name_generate_thread: Optional[Thread] = None
def process(
@@ -115,7 +118,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
]:
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
)
@@ -136,9 +139,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata
extras["metadata"] = self._task_state.metadata.model_dump()
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation_mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse(
@@ -277,7 +280,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
)
if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=output_moderation_answer
)
with Session(db.engine) as session:
# Save message
@@ -286,9 +291,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message_end_resp = self._message_end_to_stream_response()
yield message_end_resp
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
self._message_cycle_manager.handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent):
annotation = self._handle_annotation_reply(event)
annotation = self._message_cycle_manager.handle_annotation_reply(event)
if annotation:
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, QueueAgentThoughtEvent):
@@ -296,7 +301,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if agent_thought_response is not None:
yield agent_thought_response
elif isinstance(event, QueueMessageFileEvent):
response = self._message_file_to_stream_response(event)
response = self._message_cycle_manager.message_file_to_stream_response(event)
if response:
yield response
elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
@@ -318,7 +323,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent):
yield self._message_to_stream_response(
yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
)
@@ -328,7 +333,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message_id=self._message_id,
)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_replace_to_stream_response(answer=event.text)
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
else:
@@ -372,9 +377,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message.provider_response_latency = time.perf_counter() - self._start_at
message.total_price = usage.total_price
message.currency = usage.currency
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message.message_metadata = self._task_state.metadata.model_dump_json()
if trace_manager:
trace_manager.add_trace_task(
@@ -423,16 +426,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
Message end to stream response.
:return:
"""
self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)
extras = {}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata
self._task_state.metadata.usage = self._task_state.llm_result.usage
metadata_dict = self._task_state.metadata.model_dump()
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message_id,
metadata=extras.get("metadata", {}),
metadata=metadata_dict,
)
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:

View File

@@ -17,6 +17,8 @@ from core.app.entities.queue_entities import (
QueueRetrieverResourcesEvent,
)
from core.app.entities.task_entities import (
AnnotationReply,
AnnotationReplyAccount,
EasyUITaskState,
MessageFileStreamResponse,
MessageReplaceStreamResponse,
@@ -30,7 +32,7 @@ from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
from services.annotation_service import AppAnnotationService
class MessageCycleManage:
class MessageCycleManager:
def __init__(
self,
*,
@@ -45,7 +47,7 @@ class MessageCycleManage:
self._application_generate_entity = application_generate_entity
self._task_state = task_state
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
"""
Generate conversation name.
:param conversation_id: conversation id
@@ -102,7 +104,7 @@ class MessageCycleManage:
db.session.commit()
db.session.close()
def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
"""
Handle annotation reply.
:param event: event
@@ -111,25 +113,28 @@ class MessageCycleManage:
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation:
account = annotation.account
self._task_state.metadata["annotation_reply"] = {
"id": annotation.id,
"account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
}
self._task_state.metadata.annotation_reply = AnnotationReply(
id=annotation.id,
account=AnnotationReplyAccount(
id=annotation.account_id,
name=account.name if account else "Dify user",
),
)
return annotation
return None
def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
"""
Handle retriever resources.
:param event: event
:return:
"""
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata["retriever_resources"] = event.retriever_resources
self._task_state.metadata.retriever_resources = event.retriever_resources
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
"""
Message file to stream response.
:param event: event
@@ -166,7 +171,7 @@ class MessageCycleManage:
return None
def _message_to_stream_response(
def message_to_stream_response(
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
) -> MessageStreamResponse:
"""
@@ -182,7 +187,7 @@ class MessageCycleManage:
from_variable_selector=from_variable_selector,
)
def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
"""
Message replace to stream response.
:param answer: answer