fix: correct agent node token counting to properly separate prompt and completion tokens (#24368)
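
Previously, from_metadata treated a payload that reported only total_tokens as if every token were a completion token (the removed branch forced completion_tokens = total_tokens), so prompt and completion counts were conflated downstream. The corrected logic reads all three counts independently and derives total_tokens only when it is missing. A minimal standalone sketch of the new derivation (the derive_token_counts helper is illustrative only, not part of this commit):

    def derive_token_counts(metadata: dict) -> tuple[int, int, int]:
        # Read each count independently; absent keys default to 0.
        prompt_tokens = metadata.get("prompt_tokens", 0)
        completion_tokens = metadata.get("completion_tokens", 0)
        total_tokens = metadata.get("total_tokens", 0)
        # Derive the total only when it was not reported explicitly.
        if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0):
            total_tokens = prompt_tokens + completion_tokens
        return prompt_tokens, completion_tokens, total_tokens

    # Old behavior: {"total_tokens": 100} forced completion_tokens to 100.
    # New behavior: {"total_tokens": 100}                        -> (0, 0, 100)
    #               {"prompt_tokens": 60, "completion_tokens": 40} -> (60, 40, 100)

The diff adds an LLMUsageMetadata TypedDict for the payload and applies this derivation inside LLMUsage.from_metadata.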
@@ -3,7 +3,7 @@ from __future__ import annotations
 from collections.abc import Mapping, Sequence
 from decimal import Decimal
 from enum import StrEnum
-from typing import Any, Optional
+from typing import Any, Optional, TypedDict, Union
 
 from pydantic import BaseModel, Field
 
@@ -20,6 +20,26 @@ class LLMMode(StrEnum):
     CHAT = "chat"
 
 
+class LLMUsageMetadata(TypedDict, total=False):
+    """
+    TypedDict for LLM usage metadata.
+    All fields are optional.
+    """
+
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+    prompt_unit_price: Union[float, str]
+    completion_unit_price: Union[float, str]
+    total_price: Union[float, str]
+    currency: str
+    prompt_price_unit: Union[float, str]
+    completion_price_unit: Union[float, str]
+    prompt_price: Union[float, str]
+    completion_price: Union[float, str]
+    latency: float
+
+
 class LLMUsage(ModelUsage):
     """
     Model class for llm usage.
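
Because LLMUsageMetadata is declared with total=False, every key is optional, so a partial payload type-checks without Optional wrappers. A hypothetical partial payload for illustration (the import path is taken from the agent-node hunk below):

    from core.model_runtime.entities.llm_entities import LLMUsageMetadata

    # Partial payload; omitted keys are allowed because total=False.
    metadata: LLMUsageMetadata = {
        "prompt_tokens": 60,
        "completion_tokens": 40,
        # "total_tokens" omitted; from_metadata() derives it as 60 + 40 = 100
    }

The next hunk rewires from_metadata around this type.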
@@ -56,23 +76,27 @@ class LLMUsage(ModelUsage):
     )
 
     @classmethod
-    def from_metadata(cls, metadata: dict) -> LLMUsage:
+    def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage:
         """
         Create LLMUsage instance from metadata dictionary with default values.
 
         Args:
-            metadata: Dictionary containing usage metadata
+            metadata: TypedDict containing usage metadata
 
         Returns:
             LLMUsage instance with values from metadata or defaults
         """
-        total_tokens = metadata.get("total_tokens", 0)
+        prompt_tokens = metadata.get("prompt_tokens", 0)
         completion_tokens = metadata.get("completion_tokens", 0)
-        if total_tokens > 0 and completion_tokens == 0:
-            completion_tokens = total_tokens
+        total_tokens = metadata.get("total_tokens", 0)
+
+        # If total_tokens is not provided but prompt and completion tokens are,
+        # calculate total_tokens
+        if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0):
+            total_tokens = prompt_tokens + completion_tokens
 
         return cls(
-            prompt_tokens=metadata.get("prompt_tokens", 0),
+            prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
             prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))),
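
An aside on the unchanged context line above: the price fields go through Decimal(str(...)) rather than Decimal(...) directly. Passing a float straight to Decimal preserves its binary representation error, while converting via str keeps the human-readable value:

    from decimal import Decimal

    Decimal(0.1)       # Decimal('0.1000000000000000055511151231257827021181583404541015625')
    Decimal(str(0.1))  # Decimal('0.1')

The remaining two hunks update the agent node module to pass its popped execution metadata through the new type.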
@@ -13,7 +13,7 @@ from core.agent.strategy.plugin import PluginAgentStrategy
 from core.file import File, FileTransferMethod
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.plugin.entities.request import InvokeCredentials
@@ -559,7 +559,7 @@ class AgentNode(BaseNode):
             assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
             if node_type == NodeType.AGENT:
                 msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
-                llm_usage = LLMUsage.from_metadata(msg_metadata)
+                llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
                 agent_execution_metadata = {
                     WorkflowNodeExecutionMetadataKey(key): value
                     for key, value in msg_metadata.items()
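
cast(LLMUsageMetadata, msg_metadata) only narrows the static type: typing.cast is a runtime no-op with no validation, so the popped execution_metadata dict passes through unchanged (this assumes cast is already imported from typing in the agent node module; the visible import hunk does not show it). A sketch of the equivalence:

    from typing import cast

    from core.model_runtime.entities.llm_entities import LLMUsageMetadata

    msg_metadata = {"prompt_tokens": 60, "completion_tokens": 40}
    usage_input = cast(LLMUsageMetadata, msg_metadata)  # no copy, no validation
    assert usage_input is msg_metadata  # same object at runtime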