Refactor/message cycle manage and knowledge retrieval (#20460)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2025-05-30 14:36:44 +08:00
committed by GitHub
parent 5a991295e0
commit a6ea15e63c
14 changed files with 222 additions and 181 deletions

View File

@@ -1,9 +1,10 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any, Optional
from pydantic import BaseModel, Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import NodeType
@@ -82,7 +83,7 @@ class NodeRunStreamChunkEvent(BaseNodeEvent):
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
retriever_resources: list[dict] = Field(..., description="retriever resources")
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")

View File

@@ -1,8 +1,10 @@
from collections.abc import Sequence
from datetime import datetime
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@@ -17,7 +19,7 @@ class RunStreamChunkEvent(BaseModel):
class RunRetrieverResourceEvent(BaseModel):
retriever_resources: list[dict] = Field(..., description="retriever resources")
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")

View File

@@ -43,6 +43,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.variables import (
ArrayAnySegment,
ArrayFileSegment,
@@ -474,7 +475,7 @@ class LLMNode(BaseNode[LLMNodeData]):
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
elif isinstance(context_value_variable, ArraySegment):
context_str = ""
original_retriever_resource = []
original_retriever_resource: list[RetrievalSourceMetadata] = []
for item in context_value_variable.value:
if isinstance(item, str):
context_str += item + "\n"
@@ -492,7 +493,7 @@ class LLMNode(BaseNode[LLMNodeData]):
retriever_resources=original_retriever_resource, context=context_str.strip()
)
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
def _convert_to_original_retriever_resource(self, context_dict: dict):
if (
"metadata" in context_dict
and "_source" in context_dict["metadata"]
@@ -500,24 +501,24 @@ class LLMNode(BaseNode[LLMNodeData]):
):
metadata = context_dict.get("metadata", {})
source = {
"position": metadata.get("position"),
"dataset_id": metadata.get("dataset_id"),
"dataset_name": metadata.get("dataset_name"),
"document_id": metadata.get("document_id"),
"document_name": metadata.get("document_name"),
"data_source_type": metadata.get("data_source_type"),
"segment_id": metadata.get("segment_id"),
"retriever_from": metadata.get("retriever_from"),
"score": metadata.get("score"),
"hit_count": metadata.get("segment_hit_count"),
"word_count": metadata.get("segment_word_count"),
"segment_position": metadata.get("segment_position"),
"index_node_hash": metadata.get("segment_index_node_hash"),
"content": context_dict.get("content"),
"page": metadata.get("page"),
"doc_metadata": metadata.get("doc_metadata"),
}
source = RetrievalSourceMetadata(
position=metadata.get("position"),
dataset_id=metadata.get("dataset_id"),
dataset_name=metadata.get("dataset_name"),
document_id=metadata.get("document_id"),
document_name=metadata.get("document_name"),
data_source_type=metadata.get("data_source_type"),
segment_id=metadata.get("segment_id"),
retriever_from=metadata.get("retriever_from"),
score=metadata.get("score"),
hit_count=metadata.get("segment_hit_count"),
word_count=metadata.get("segment_word_count"),
segment_position=metadata.get("segment_position"),
index_node_hash=metadata.get("segment_index_node_hash"),
content=context_dict.get("content"),
page=metadata.get("page"),
doc_metadata=metadata.get("doc_metadata"),
)
return source