Refactor agent history organization and initialization of agent scrat… (#2495)

This commit is contained in:
Yeuoly
2024-02-20 19:03:43 +08:00
committed by GitHub
parent e6cd7b0467
commit ae3ad59b16
2 changed files with 98 additions and 14 deletions

View File

@@ -12,6 +12,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -39,6 +40,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
self._repack_app_orchestration_config(app_orchestration_config)
agent_scratchpad: list[AgentScratchpadUnit] = []
self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
# check model mode
if self.app_orchestration_config.model_config.mode == "completion":
@@ -327,6 +329,39 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
continue
return instruction
def _init_agent_scratchpad(self,
agent_scratchpad: list[AgentScratchpadUnit],
messages: list[PromptMessage]
) -> list[AgentScratchpadUnit]:
"""
init agent scratchpad
"""
current_scratchpad: AgentScratchpadUnit = None
for message in messages:
if isinstance(message, AssistantPromptMessage):
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content,
action_str='',
action=None,
observation=None
)
if message.tool_calls:
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments)
)
except:
pass
agent_scratchpad.append(current_scratchpad)
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
current_scratchpad.observation = message.content
return agent_scratchpad
def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
"""