fix: Improve create_agent_thought and save_agent_thought Logic (#21263)
@@ -280,7 +280,7 @@ class BaseAgentRunner(AppRunner):

     def create_agent_thought(
         self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
-    ) -> MessageAgentThought:
+    ) -> str:
         """
         Create agent thought
         """
@@ -313,16 +313,15 @@ class BaseAgentRunner(AppRunner):

         db.session.add(thought)
         db.session.commit()
-        db.session.refresh(thought)
+        agent_thought_id = str(thought.id)
+        self.agent_thought_count += 1
         db.session.close()

-        self.agent_thought_count += 1
-
-        return thought
+        return agent_thought_id

     def save_agent_thought(
         self,
-        agent_thought: MessageAgentThought,
+        agent_thought_id: str,
         tool_name: str | None,
         tool_input: Union[str, dict, None],
         thought: str | None,
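The two hunks above change create_agent_thought to return the new row's primary key as a plain string: the id is captured and the iteration counter bumped before db.session.close(), and the db.session.refresh(thought) call is dropped. The standalone sketch below is not Dify code; the Thought model and in-memory SQLite engine are illustrative stand-ins, assuming SQLAlchemy 2.x. It shows the session behaviour this pattern sidesteps: once a committed instance is detached from a closed session, its expired attributes can no longer be loaded, whereas a captured string id stays usable.

import uuid

from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
from sqlalchemy.orm.exc import DetachedInstanceError


class Base(DeclarativeBase):
    pass


class Thought(Base):
    __tablename__ = "thoughts"

    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    message: Mapped[str] = mapped_column(String(255))


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

# New pattern: read the primary key while the session is still open, then pass
# the plain string around instead of the ORM instance.
with Session(engine) as session:  # expire_on_commit=True by default
    thought = Thought(message="hello")
    session.add(thought)
    session.commit()
    agent_thought_id = str(thought.id)  # reloads the expired attribute while still attached
print(agent_thought_id)  # the string stays valid after the session is closed

# Without a refresh or an attribute access before close(), the expired
# attributes of the now-detached instance can no longer be loaded.
session = Session(engine)
stale = Thought(message="world")
session.add(stale)
session.commit()
session.close()
try:
    print(stale.id)
except DetachedInstanceError as exc:
    print(f"detached: {exc}")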
@@ -335,12 +334,9 @@ class BaseAgentRunner(AppRunner):
         """
         Save agent thought
         """
-        updated_agent_thought = (
-            db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought.id).first()
-        )
-        if not updated_agent_thought:
+        agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first()
+        if not agent_thought:
             raise ValueError("agent thought not found")
-        agent_thought = updated_agent_thought

         if thought:
             agent_thought.thought += thought
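save_agent_thought now takes the string id and re-loads the row inside its own session, raising ValueError when nothing matches, instead of accepting an ORM instance and re-querying by agent_thought.id. A minimal self-contained sketch of this lookup-or-raise update pattern, using an illustrative Thought model rather than Dify's MessageAgentThought and db.session:

from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Thought(Base):
    __tablename__ = "thoughts"

    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    thought: Mapped[str] = mapped_column(String(2048), default="")


def save_thought(session: Session, thought_id: str, thought: str | None) -> None:
    # Re-load the row by id in the current session instead of accepting a
    # (possibly detached) ORM instance from the caller, mirroring the diff.
    row = session.query(Thought).where(Thought.id == thought_id).first()
    if not row:
        raise ValueError("agent thought not found")

    if thought:
        row.thought += thought  # appended incrementally, as in the diff

    session.commit()


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Thought(id="thought-1", thought=""))
    session.commit()
    save_thought(session, "thought-1", "step one; ")
    save_thought(session, "thought-1", "step two")
    print(session.get(Thought, "thought-1").thought)  # step one; step two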
@@ -355,7 +351,7 @@ class BaseAgentRunner(AppRunner):
                 except Exception:
                     tool_input = json.dumps(tool_input)

-            updated_agent_thought.tool_input = tool_input
+            agent_thought.tool_input = tool_input

         if observation:
             if isinstance(observation, dict):
@@ -364,27 +360,27 @@ class BaseAgentRunner(AppRunner):
                 except Exception:
                     observation = json.dumps(observation)

-            updated_agent_thought.observation = observation
+            agent_thought.observation = observation

         if answer:
             agent_thought.answer = answer

         if messages_ids is not None and len(messages_ids) > 0:
-            updated_agent_thought.message_files = json.dumps(messages_ids)
+            agent_thought.message_files = json.dumps(messages_ids)

         if llm_usage:
-            updated_agent_thought.message_token = llm_usage.prompt_tokens
-            updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
-            updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
-            updated_agent_thought.answer_token = llm_usage.completion_tokens
-            updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
-            updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
-            updated_agent_thought.tokens = llm_usage.total_tokens
-            updated_agent_thought.total_price = llm_usage.total_price
+            agent_thought.message_token = llm_usage.prompt_tokens
+            agent_thought.message_price_unit = llm_usage.prompt_price_unit
+            agent_thought.message_unit_price = llm_usage.prompt_unit_price
+            agent_thought.answer_token = llm_usage.completion_tokens
+            agent_thought.answer_price_unit = llm_usage.completion_price_unit
+            agent_thought.answer_unit_price = llm_usage.completion_unit_price
+            agent_thought.tokens = llm_usage.total_tokens
+            agent_thought.total_price = llm_usage.total_price

         # check if tool labels is not empty
-        labels = updated_agent_thought.tool_labels or {}
-        tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
+        labels = agent_thought.tool_labels or {}
+        tools = agent_thought.tool.split(";") if agent_thought.tool else []
         for tool in tools:
             if not tool:
                 continue
@@ -395,7 +391,7 @@ class BaseAgentRunner(AppRunner):
             else:
                 labels[tool] = {"en_US": tool, "zh_Hans": tool}

-        updated_agent_thought.tool_labels_str = json.dumps(labels)
+        agent_thought.tool_labels_str = json.dumps(labels)

         if tool_invoke_meta is not None:
             if isinstance(tool_invoke_meta, dict):
@@ -404,7 +400,7 @@ class BaseAgentRunner(AppRunner):
                 except Exception:
                     tool_invoke_meta = json.dumps(tool_invoke_meta)

-            updated_agent_thought.tool_meta_str = tool_invoke_meta
+            agent_thought.tool_meta_str = tool_invoke_meta

         db.session.commit()
         db.session.close()
@@ -97,13 +97,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):

             message_file_ids: list[str] = []

-            agent_thought = self.create_agent_thought(
+            agent_thought_id = self.create_agent_thought(
                 message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
             )

             if iteration_step > 1:
                 self.queue_manager.publish(
-                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                 )

             # recalc llm max tokens
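On the caller side, the string returned by create_agent_thought now feeds both QueueAgentThoughtEvent(agent_thought_id=...) and the later save_agent_thought calls, so no ORM object is threaded through the agent loop. A rough sketch of that shape with stub types — AgentThoughtEvent, FakeQueue, and the lambdas are hypothetical stand-ins for Dify's event class, queue manager, and runner methods:

from dataclasses import dataclass, field
from typing import Callable


@dataclass
class AgentThoughtEvent:
    agent_thought_id: str  # only the string id travels with the event, never the ORM row


@dataclass
class FakeQueue:
    events: list[AgentThoughtEvent] = field(default_factory=list)

    def publish(self, event: AgentThoughtEvent) -> None:
        self.events.append(event)


def run_iteration(
    queue: FakeQueue,
    create_thought: Callable[[], str],
    save_thought: Callable[..., None],
) -> None:
    # create_thought() -> str and save_thought(thought_id, **fields) stand in for
    # BaseAgentRunner.create_agent_thought / save_agent_thought in the diff above.
    agent_thought_id = create_thought()
    queue.publish(AgentThoughtEvent(agent_thought_id=agent_thought_id))

    # ... invoke the model and tools here, then persist against the same id ...
    save_thought(agent_thought_id, thought="intermediate reasoning", tool_name="")


queue = FakeQueue()
run_iteration(
    queue,
    create_thought=lambda: "thought-1",
    save_thought=lambda thought_id, **fields: print(thought_id, fields),
)
print([e.agent_thought_id for e in queue.events])  # ['thought-1']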
@@ -133,7 +133,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             # publish agent thought if it's first iteration
             if iteration_step == 1:
                 self.queue_manager.publish(
-                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                 )

             for chunk in react_chunks:
@@ -168,7 +168,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 usage_dict["usage"] = LLMUsage.empty_usage()

             self.save_agent_thought(
-                agent_thought=agent_thought,
+                agent_thought_id=agent_thought_id,
                 tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
                 tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
                 tool_invoke_meta={},
@@ -181,7 +181,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):

             if not scratchpad.is_final():
                 self.queue_manager.publish(
-                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                 )

             if not scratchpad.action:
@@ -212,7 +212,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 scratchpad.agent_response = tool_invoke_response

                 self.save_agent_thought(
-                    agent_thought=agent_thought,
+                    agent_thought_id=agent_thought_id,
                     tool_name=scratchpad.action.action_name,
                     tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
                     thought=scratchpad.thought or "",
@@ -224,7 +224,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 )

                 self.queue_manager.publish(
-                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                 )

                 # update prompt tool message
@@ -244,7 +244,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):

         # save agent thought
         self.save_agent_thought(
-            agent_thought=agent_thought,
+            agent_thought_id=agent_thought_id,
             tool_name="",
             tool_input={},
             tool_invoke_meta={},
@@ -80,7 +80,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             prompt_messages_tools = []

             message_file_ids: list[str] = []
-            agent_thought = self.create_agent_thought(
+            agent_thought_id = self.create_agent_thought(
                 message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
             )

@@ -114,7 +114,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 for chunk in chunks:
                     if is_first_chunk:
                         self.queue_manager.publish(
-                            QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                            QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                         )
                         is_first_chunk = False
                     # check if there is any tool call
@@ -172,7 +172,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                     result.message.content = ""

                 self.queue_manager.publish(
-                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                 )

                 yield LLMResultChunk(
@@ -205,7 +205,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):

             # save thought
             self.save_agent_thought(
-                agent_thought=agent_thought,
+                agent_thought_id=agent_thought_id,
                 tool_name=tool_call_names,
                 tool_input=tool_call_inputs,
                 thought=response,
@@ -216,7 +216,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 llm_usage=current_llm_usage,
             )
             self.queue_manager.publish(
-                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
             )

             final_answer += response + "\n"
@@ -276,7 +276,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             if len(tool_responses) > 0:
                 # save agent thought
                 self.save_agent_thought(
-                    agent_thought=agent_thought,
+                    agent_thought_id=agent_thought_id,
                     tool_name="",
                     tool_input="",
                     thought="",
@@ -291,7 +291,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                     messages_ids=message_file_ids,
                 )
                 self.queue_manager.publish(
-                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                 )

             # update prompt tool