fix: Improve create_agent_thought and save_agent_thought Logic (#21263)

This commit is contained in:
Will
2025-07-27 11:06:37 +08:00
committed by GitHub
parent 665fcad655
commit 67a0751cf3
3 changed files with 37 additions and 41 deletions

View File

@@ -280,7 +280,7 @@ class BaseAgentRunner(AppRunner):
def create_agent_thought( def create_agent_thought(
self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str] self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
) -> MessageAgentThought: ) -> str:
""" """
Create agent thought Create agent thought
""" """
@@ -313,16 +313,15 @@ class BaseAgentRunner(AppRunner):
db.session.add(thought) db.session.add(thought)
db.session.commit() db.session.commit()
db.session.refresh(thought) agent_thought_id = str(thought.id)
self.agent_thought_count += 1
db.session.close() db.session.close()
self.agent_thought_count += 1 return agent_thought_id
return thought
def save_agent_thought( def save_agent_thought(
self, self,
agent_thought: MessageAgentThought, agent_thought_id: str,
tool_name: str | None, tool_name: str | None,
tool_input: Union[str, dict, None], tool_input: Union[str, dict, None],
thought: str | None, thought: str | None,
@@ -335,12 +334,9 @@ class BaseAgentRunner(AppRunner):
""" """
Save agent thought Save agent thought
""" """
updated_agent_thought = ( agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first()
db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought.id).first() if not agent_thought:
)
if not updated_agent_thought:
raise ValueError("agent thought not found") raise ValueError("agent thought not found")
agent_thought = updated_agent_thought
if thought: if thought:
agent_thought.thought += thought agent_thought.thought += thought
@@ -355,7 +351,7 @@ class BaseAgentRunner(AppRunner):
except Exception: except Exception:
tool_input = json.dumps(tool_input) tool_input = json.dumps(tool_input)
updated_agent_thought.tool_input = tool_input agent_thought.tool_input = tool_input
if observation: if observation:
if isinstance(observation, dict): if isinstance(observation, dict):
@@ -364,27 +360,27 @@ class BaseAgentRunner(AppRunner):
except Exception: except Exception:
observation = json.dumps(observation) observation = json.dumps(observation)
updated_agent_thought.observation = observation agent_thought.observation = observation
if answer: if answer:
agent_thought.answer = answer agent_thought.answer = answer
if messages_ids is not None and len(messages_ids) > 0: if messages_ids is not None and len(messages_ids) > 0:
updated_agent_thought.message_files = json.dumps(messages_ids) agent_thought.message_files = json.dumps(messages_ids)
if llm_usage: if llm_usage:
updated_agent_thought.message_token = llm_usage.prompt_tokens agent_thought.message_token = llm_usage.prompt_tokens
updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit agent_thought.message_price_unit = llm_usage.prompt_price_unit
updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price agent_thought.message_unit_price = llm_usage.prompt_unit_price
updated_agent_thought.answer_token = llm_usage.completion_tokens agent_thought.answer_token = llm_usage.completion_tokens
updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit agent_thought.answer_price_unit = llm_usage.completion_price_unit
updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price agent_thought.answer_unit_price = llm_usage.completion_unit_price
updated_agent_thought.tokens = llm_usage.total_tokens agent_thought.tokens = llm_usage.total_tokens
updated_agent_thought.total_price = llm_usage.total_price agent_thought.total_price = llm_usage.total_price
# check if tool labels is not empty # check if tool labels is not empty
labels = updated_agent_thought.tool_labels or {} labels = agent_thought.tool_labels or {}
tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else [] tools = agent_thought.tool.split(";") if agent_thought.tool else []
for tool in tools: for tool in tools:
if not tool: if not tool:
continue continue
@@ -395,7 +391,7 @@ class BaseAgentRunner(AppRunner):
else: else:
labels[tool] = {"en_US": tool, "zh_Hans": tool} labels[tool] = {"en_US": tool, "zh_Hans": tool}
updated_agent_thought.tool_labels_str = json.dumps(labels) agent_thought.tool_labels_str = json.dumps(labels)
if tool_invoke_meta is not None: if tool_invoke_meta is not None:
if isinstance(tool_invoke_meta, dict): if isinstance(tool_invoke_meta, dict):
@@ -404,7 +400,7 @@ class BaseAgentRunner(AppRunner):
except Exception: except Exception:
tool_invoke_meta = json.dumps(tool_invoke_meta) tool_invoke_meta = json.dumps(tool_invoke_meta)
updated_agent_thought.tool_meta_str = tool_invoke_meta agent_thought.tool_meta_str = tool_invoke_meta
db.session.commit() db.session.commit()
db.session.close() db.session.close()

View File

@@ -97,13 +97,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
message_file_ids: list[str] = [] message_file_ids: list[str] = []
agent_thought = self.create_agent_thought( agent_thought_id = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
) )
if iteration_step > 1: if iteration_step > 1:
self.queue_manager.publish( self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
) )
# recalc llm max tokens # recalc llm max tokens
@@ -133,7 +133,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# publish agent thought if it's first iteration # publish agent thought if it's first iteration
if iteration_step == 1: if iteration_step == 1:
self.queue_manager.publish( self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
) )
for chunk in react_chunks: for chunk in react_chunks:
@@ -168,7 +168,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
usage_dict["usage"] = LLMUsage.empty_usage() usage_dict["usage"] = LLMUsage.empty_usage()
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought_id=agent_thought_id,
tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""), tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {}, tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
tool_invoke_meta={}, tool_invoke_meta={},
@@ -181,7 +181,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
if not scratchpad.is_final(): if not scratchpad.is_final():
self.queue_manager.publish( self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
) )
if not scratchpad.action: if not scratchpad.action:
@@ -212,7 +212,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
scratchpad.agent_response = tool_invoke_response scratchpad.agent_response = tool_invoke_response
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought_id=agent_thought_id,
tool_name=scratchpad.action.action_name, tool_name=scratchpad.action.action_name,
tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
thought=scratchpad.thought or "", thought=scratchpad.thought or "",
@@ -224,7 +224,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
) )
self.queue_manager.publish( self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
) )
# update prompt tool message # update prompt tool message
@@ -244,7 +244,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# save agent thought # save agent thought
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought_id=agent_thought_id,
tool_name="", tool_name="",
tool_input={}, tool_input={},
tool_invoke_meta={}, tool_invoke_meta={},

View File

@@ -80,7 +80,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
prompt_messages_tools = [] prompt_messages_tools = []
message_file_ids: list[str] = [] message_file_ids: list[str] = []
agent_thought = self.create_agent_thought( agent_thought_id = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
) )
@@ -114,7 +114,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
for chunk in chunks: for chunk in chunks:
if is_first_chunk: if is_first_chunk:
self.queue_manager.publish( self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
) )
is_first_chunk = False is_first_chunk = False
# check if there is any tool call # check if there is any tool call
@@ -172,7 +172,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
result.message.content = "" result.message.content = ""
self.queue_manager.publish( self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
) )
yield LLMResultChunk( yield LLMResultChunk(
@@ -205,7 +205,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# save thought # save thought
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought_id=agent_thought_id,
tool_name=tool_call_names, tool_name=tool_call_names,
tool_input=tool_call_inputs, tool_input=tool_call_inputs,
thought=response, thought=response,
@@ -216,7 +216,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
llm_usage=current_llm_usage, llm_usage=current_llm_usage,
) )
self.queue_manager.publish( self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
) )
final_answer += response + "\n" final_answer += response + "\n"
@@ -276,7 +276,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if len(tool_responses) > 0: if len(tool_responses) > 0:
# save agent thought # save agent thought
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought_id=agent_thought_id,
tool_name="", tool_name="",
tool_input="", tool_input="",
thought="", thought="",
@@ -291,7 +291,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
messages_ids=message_file_ids, messages_ids=message_file_ids,
) )
self.queue_manager.publish( self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
) )
# update prompt tool # update prompt tool