Refactor: message cycle management and knowledge retrieval (#20460)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes import NodeType
|
||||
@@ -82,7 +83,7 @@ class NodeRunStreamChunkEvent(BaseNodeEvent):
|
||||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
|
||||
retriever_resources: list[dict] = Field(..., description="retriever resources")
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
|
@@ -1,8 +1,10 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
|
||||
@@ -17,7 +19,7 @@ class RunStreamChunkEvent(BaseModel):
|
||||
|
||||
|
||||
class RunRetrieverResourceEvent(BaseModel):
|
||||
retriever_resources: list[dict] = Field(..., description="retriever resources")
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
|
@@ -43,6 +43,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.variables import (
|
||||
ArrayAnySegment,
|
||||
ArrayFileSegment,
|
||||
@@ -474,7 +475,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
|
||||
elif isinstance(context_value_variable, ArraySegment):
|
||||
context_str = ""
|
||||
original_retriever_resource = []
|
||||
original_retriever_resource: list[RetrievalSourceMetadata] = []
|
||||
for item in context_value_variable.value:
|
||||
if isinstance(item, str):
|
||||
context_str += item + "\n"
|
||||
@@ -492,7 +493,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
retriever_resources=original_retriever_resource, context=context_str.strip()
|
||||
)
|
||||
|
||||
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
|
||||
def _convert_to_original_retriever_resource(self, context_dict: dict):
|
||||
if (
|
||||
"metadata" in context_dict
|
||||
and "_source" in context_dict["metadata"]
|
||||
@@ -500,24 +501,24 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
):
|
||||
metadata = context_dict.get("metadata", {})
|
||||
|
||||
source = {
|
||||
"position": metadata.get("position"),
|
||||
"dataset_id": metadata.get("dataset_id"),
|
||||
"dataset_name": metadata.get("dataset_name"),
|
||||
"document_id": metadata.get("document_id"),
|
||||
"document_name": metadata.get("document_name"),
|
||||
"data_source_type": metadata.get("data_source_type"),
|
||||
"segment_id": metadata.get("segment_id"),
|
||||
"retriever_from": metadata.get("retriever_from"),
|
||||
"score": metadata.get("score"),
|
||||
"hit_count": metadata.get("segment_hit_count"),
|
||||
"word_count": metadata.get("segment_word_count"),
|
||||
"segment_position": metadata.get("segment_position"),
|
||||
"index_node_hash": metadata.get("segment_index_node_hash"),
|
||||
"content": context_dict.get("content"),
|
||||
"page": metadata.get("page"),
|
||||
"doc_metadata": metadata.get("doc_metadata"),
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
position=metadata.get("position"),
|
||||
dataset_id=metadata.get("dataset_id"),
|
||||
dataset_name=metadata.get("dataset_name"),
|
||||
document_id=metadata.get("document_id"),
|
||||
document_name=metadata.get("document_name"),
|
||||
data_source_type=metadata.get("data_source_type"),
|
||||
segment_id=metadata.get("segment_id"),
|
||||
retriever_from=metadata.get("retriever_from"),
|
||||
score=metadata.get("score"),
|
||||
hit_count=metadata.get("segment_hit_count"),
|
||||
word_count=metadata.get("segment_word_count"),
|
||||
segment_position=metadata.get("segment_position"),
|
||||
index_node_hash=metadata.get("segment_index_node_hash"),
|
||||
content=context_dict.get("content"),
|
||||
page=metadata.get("page"),
|
||||
doc_metadata=metadata.get("doc_metadata"),
|
||||
)
|
||||
|
||||
return source
|
||||
|
||||
|
Reference in New Issue
Block a user