From f3c8625fe27527386cfd0f8b8f3b538b93929f98 Mon Sep 17 00:00:00 2001 From: Novice <857526207@qq.com> Date: Thu, 3 Jul 2025 14:40:47 +0800 Subject: [PATCH] fix: The statistics page cannot display the tokens consumed by agent node (#21861) --- .../model_runtime/entities/llm_entities.py | 31 +++++++++++++++++++ api/core/workflow/nodes/tool/tool_node.py | 9 ++++-- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index e52b0eba5..ace2c1f77 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -53,6 +53,37 @@ class LLMUsage(ModelUsage): latency=0.0, ) + @classmethod + def from_metadata(cls, metadata: dict) -> "LLMUsage": + """ + Create LLMUsage instance from metadata dictionary with default values. + + Args: + metadata: Dictionary containing usage metadata + + Returns: + LLMUsage instance with values from metadata or defaults + """ + total_tokens = metadata.get("total_tokens", 0) + completion_tokens = metadata.get("completion_tokens", 0) + if total_tokens > 0 and completion_tokens == 0: + completion_tokens = total_tokens + + return cls( + prompt_tokens=metadata.get("prompt_tokens", 0), + completion_tokens=completion_tokens, + total_tokens=total_tokens, + prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))), + completion_unit_price=Decimal(str(metadata.get("completion_unit_price", 0))), + total_price=Decimal(str(metadata.get("total_price", 0))), + currency=metadata.get("currency", "USD"), + prompt_price_unit=Decimal(str(metadata.get("prompt_price_unit", 0))), + completion_price_unit=Decimal(str(metadata.get("completion_price_unit", 0))), + prompt_price=Decimal(str(metadata.get("prompt_price", 0))), + completion_price=Decimal(str(metadata.get("completion_price", 0))), + latency=metadata.get("latency", 0.0), + ) + def plus(self, other: "LLMUsage") -> "LLMUsage": """ Add two LLMUsage instances together. diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 4d15d78a9..a4be02d86 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Session from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file import File, FileTransferMethod +from core.model_runtime.entities.llm_entities import LLMUsage from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.plugin import PluginInstaller from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter @@ -208,7 +209,7 @@ class ToolNode(BaseNode[ToolNodeData]): agent_logs: list[AgentLogEvent] = [] agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} - + llm_usage: LLMUsage | None = None variables: dict[str, Any] = {} for message in message_stream: @@ -276,9 +277,10 @@ class ToolNode(BaseNode[ToolNodeData]): elif message.type == ToolInvokeMessage.MessageType.JSON: assert isinstance(message.message, ToolInvokeMessage.JsonMessage) if self.node_type == NodeType.AGENT: - msg_metadata = message.message.json_object.pop("execution_metadata", {}) + msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) + llm_usage = LLMUsage.from_metadata(msg_metadata) agent_execution_metadata = { - key: value + WorkflowNodeExecutionMetadataKey(key): value for key, value in msg_metadata.items() if key in WorkflowNodeExecutionMetadataKey.__members__.values() } @@ -377,6 +379,7 @@ class ToolNode(BaseNode[ToolNodeData]): WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs, }, inputs=parameters_for_log, + llm_usage=llm_usage, ) )