feat: agent node add memory (#15976)

Author: Novice
Date: 2025-04-03 16:40:58 +08:00
Committed by: GitHub
Parent: 3d76f09c3a
Commit: dcdec98c8e
7 changed files with 116 additions and 20 deletions

View File

@@ -1,15 +1,18 @@
 import json
 from collections.abc import Generator, Mapping, Sequence
-from typing import Any, cast
+from typing import Any, Optional, cast

 from core.agent.entities import AgentToolEntity
 from core.agent.plugin_entities import AgentStrategyParameter
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.plugin.manager.exc import PluginDaemonClientSideError
 from core.plugin.manager.plugin import PluginInstallationManager
+from core.provider_manager import ProviderManager
 from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
 from core.tools.tool_manager import ToolManager
+from core.variables.segments import StringSegment
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
@@ -19,7 +22,9 @@ from core.workflow.nodes.enums import NodeType
 from core.workflow.nodes.event.event import RunCompletedEvent
 from core.workflow.nodes.tool.tool_node import ToolNode
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
+from extensions.ext_database import db
 from factories.agent_factory import get_plugin_agent_strategy
+from models.model import Conversation
 from models.workflow import WorkflowNodeExecutionStatus
@@ -233,17 +238,20 @@ class AgentNode(ToolNode):
                 value = tool_value
                 if parameter.type == "model-selector":
                     value = cast(dict[str, Any], value)
-                    model_instance = ModelManager().get_model_instance(
-                        tenant_id=self.tenant_id,
-                        provider=value.get("provider", ""),
-                        model_type=ModelType(value.get("model_type", "")),
-                        model=value.get("model", ""),
-                    )
-                    models = model_instance.model_type_instance.plugin_model_provider.declaration.models
-                    finded_model = next((model for model in models if model.model == value.get("model", "")), None)
-                    value["entity"] = finded_model.model_dump(mode="json") if finded_model else None
+                    model_instance, model_schema = self._fetch_model(value)
+                    # memory config
+                    history_prompt_messages = []
+                    if node_data.memory:
+                        memory = self._fetch_memory(model_instance)
+                        if memory:
+                            prompt_messages = memory.get_history_prompt_messages(
+                                message_limit=node_data.memory.window.size if node_data.memory.window.size else None
+                            )
+                            history_prompt_messages = [
+                                prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
+                            ]
+                    value["history_prompt_messages"] = history_prompt_messages
+                    value["entity"] = model_schema.model_dump(mode="json") if model_schema else None
             result[parameter_name] = value

         return result
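For context on the added branch above: when memory is configured, prior conversation turns are serialized into plain JSON dicts and attached to the model-selector value that the agent strategy plugin receives. A rough sketch of that value is below; the provider/model names and the per-message fields are illustrative assumptions, since the diff only shows each PromptMessage being passed through model_dump(mode="json").

# Illustrative shape only; keys besides "history_prompt_messages" and "entity"
# mirror what the diff reads via value.get(...), and all values are made up.
value = {
    "provider": "openai",           # example provider id (assumption)
    "model": "gpt-4o",              # example model name (assumption)
    "model_type": "llm",
    "entity": {"model": "gpt-4o"},  # truncated stand-in for model_schema.model_dump(mode="json")
    "history_prompt_messages": [
        {"role": "user", "content": "What's the weather in Berlin?"},  # assumed fields
        {"role": "assistant", "content": "Sunny, around 22 degrees."},  # assumed fields
    ],
}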
@@ -297,3 +305,46 @@ class AgentNode(ToolNode):
         except StopIteration:
             icon = None
         return icon
+
+    def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
+        # get conversation id
+        conversation_id_variable = self.graph_runtime_state.variable_pool.get(
+            ["sys", SystemVariableKey.CONVERSATION_ID.value]
+        )
+        if not isinstance(conversation_id_variable, StringSegment):
+            return None
+        conversation_id = conversation_id_variable.value
+
+        # get conversation
+        conversation = (
+            db.session.query(Conversation)
+            .filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
+            .first()
+        )
+        if not conversation:
+            return None
+
+        memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
+        return memory
+
+    def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
+        provider_manager = ProviderManager()
+        provider_model_bundle = provider_manager.get_provider_model_bundle(
+            tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
+        )
+        model_name = value.get("model", "")
+        model_credentials = provider_model_bundle.configuration.get_current_credentials(
+            model_type=ModelType.LLM, model=model_name
+        )
+        provider_name = provider_model_bundle.configuration.provider.provider
+        model_type_instance = provider_model_bundle.model_type_instance
+        model_instance = ModelManager().get_model_instance(
+            tenant_id=self.tenant_id,
+            provider=provider_name,
+            model_type=ModelType(value.get("model_type", "")),
+            model=model_name,
+        )
+        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
+        return model_instance, model_schema
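The two new helpers split the work: _fetch_model resolves both a callable ModelInstance and the provider-declared AIModelEntity schema, while _fetch_memory only returns a TokenBufferMemory when the node runs inside a stored conversation. A self-contained sketch of that gating behaviour, using hypothetical stand-in types rather than the Dify classes, might look like this:

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeConversation:  # hypothetical stand-in for models.model.Conversation
    id: str
    app_id: str


def fetch_memory_sketch(
    system_vars: dict[str, str],
    conversations: list[FakeConversation],
    app_id: str,
) -> Optional[FakeConversation]:
    """Mirrors the guard clauses of AgentNode._fetch_memory (illustrative only)."""
    # No conversation id among the system variables -> no memory.
    conversation_id = system_vars.get("conversation_id")
    if not isinstance(conversation_id, str):
        return None
    # A conversation id that does not resolve to a stored conversation -> no memory.
    conversation = next(
        (c for c in conversations if c.app_id == app_id and c.id == conversation_id),
        None,
    )
    if conversation is None:
        return None
    # Otherwise memory can be built around the conversation.
    return conversation


if __name__ == "__main__":
    stored = [FakeConversation(id="c1", app_id="app-1")]
    assert fetch_memory_sketch({}, stored, "app-1") is None
    assert fetch_memory_sketch({"conversation_id": "c1"}, stored, "app-1") is not None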

View File

@@ -3,6 +3,7 @@ from typing import Any, Literal, Union
 from pydantic import BaseModel

+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.tools.entities.tool_entities import ToolSelector
 from core.workflow.nodes.base.entities import BaseNodeData
@@ -11,6 +12,7 @@ class AgentNodeData(BaseNodeData):
     agent_strategy_provider_name: str  # redundancy
     agent_strategy_name: str
     agent_strategy_label: str  # redundancy
+    memory: MemoryConfig | None = None


 class AgentInput(BaseModel):
     value: Union[list[str], list[ToolSelector], Any]
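With the new optional memory field, an agent node's data can carry a memory block next to the strategy fields. A hypothetical payload is sketched below; the diff above only exercises memory.window.size, so the enabled flag and the concrete strategy values are assumptions for illustration.

# Hypothetical AgentNodeData payload (values are examples, not from the commit).
agent_node_data = {
    "agent_strategy_provider_name": "example_plugin/agent",  # assumed value
    "agent_strategy_name": "function_calling",               # assumed value
    "agent_strategy_label": "Function Calling",              # assumed value
    "memory": {
        "window": {
            "enabled": True,  # assumed MemoryConfig field
            "size": 20,       # read as node_data.memory.window.size in agent_node.py
        },
    },
}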