fix: Improve create_agent_thought and save_agent_thought Logic (#21263)

This commit is contained in:
Will
2025-07-27 11:06:37 +08:00
committed by GitHub
parent 665fcad655
commit 67a0751cf3
3 changed files with 37 additions and 41 deletions

View File

@@ -280,7 +280,7 @@ class BaseAgentRunner(AppRunner):
def create_agent_thought(
self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
) -> MessageAgentThought:
) -> str:
"""
Create agent thought
"""
@@ -313,16 +313,15 @@ class BaseAgentRunner(AppRunner):
db.session.add(thought)
db.session.commit()
db.session.refresh(thought)
agent_thought_id = str(thought.id)
self.agent_thought_count += 1
db.session.close()
self.agent_thought_count += 1
return thought
return agent_thought_id
def save_agent_thought(
self,
agent_thought: MessageAgentThought,
agent_thought_id: str,
tool_name: str | None,
tool_input: Union[str, dict, None],
thought: str | None,
@@ -335,12 +334,9 @@ class BaseAgentRunner(AppRunner):
"""
Save agent thought
"""
updated_agent_thought = (
db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought.id).first()
)
if not updated_agent_thought:
agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first()
if not agent_thought:
raise ValueError("agent thought not found")
agent_thought = updated_agent_thought
if thought:
agent_thought.thought += thought
@@ -355,7 +351,7 @@ class BaseAgentRunner(AppRunner):
except Exception:
tool_input = json.dumps(tool_input)
updated_agent_thought.tool_input = tool_input
agent_thought.tool_input = tool_input
if observation:
if isinstance(observation, dict):
@@ -364,27 +360,27 @@ class BaseAgentRunner(AppRunner):
except Exception:
observation = json.dumps(observation)
updated_agent_thought.observation = observation
agent_thought.observation = observation
if answer:
agent_thought.answer = answer
if messages_ids is not None and len(messages_ids) > 0:
updated_agent_thought.message_files = json.dumps(messages_ids)
agent_thought.message_files = json.dumps(messages_ids)
if llm_usage:
updated_agent_thought.message_token = llm_usage.prompt_tokens
updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
updated_agent_thought.answer_token = llm_usage.completion_tokens
updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
updated_agent_thought.tokens = llm_usage.total_tokens
updated_agent_thought.total_price = llm_usage.total_price
agent_thought.message_token = llm_usage.prompt_tokens
agent_thought.message_price_unit = llm_usage.prompt_price_unit
agent_thought.message_unit_price = llm_usage.prompt_unit_price
agent_thought.answer_token = llm_usage.completion_tokens
agent_thought.answer_price_unit = llm_usage.completion_price_unit
agent_thought.answer_unit_price = llm_usage.completion_unit_price
agent_thought.tokens = llm_usage.total_tokens
agent_thought.total_price = llm_usage.total_price
# check if tool labels is not empty
labels = updated_agent_thought.tool_labels or {}
tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
labels = agent_thought.tool_labels or {}
tools = agent_thought.tool.split(";") if agent_thought.tool else []
for tool in tools:
if not tool:
continue
@@ -395,7 +391,7 @@ class BaseAgentRunner(AppRunner):
else:
labels[tool] = {"en_US": tool, "zh_Hans": tool}
updated_agent_thought.tool_labels_str = json.dumps(labels)
agent_thought.tool_labels_str = json.dumps(labels)
if tool_invoke_meta is not None:
if isinstance(tool_invoke_meta, dict):
@@ -404,7 +400,7 @@ class BaseAgentRunner(AppRunner):
except Exception:
tool_invoke_meta = json.dumps(tool_invoke_meta)
updated_agent_thought.tool_meta_str = tool_invoke_meta
agent_thought.tool_meta_str = tool_invoke_meta
db.session.commit()
db.session.close()