fix: correct agent node token counting to properly separate prompt and completion tokens (#24368)

This commit is contained in:
-LAN-
2025-08-23 11:00:14 +08:00
committed by GitHub
parent 0a2111f33d
commit 2e47558f4b
3 changed files with 181 additions and 9 deletions

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Mapping, Sequence
from decimal import Decimal
from enum import StrEnum
from typing import Any, Optional
from typing import Any, Optional, TypedDict, Union
from pydantic import BaseModel, Field
@@ -20,6 +20,26 @@ class LLMMode(StrEnum):
CHAT = "chat"
class LLMUsageMetadata(TypedDict, total=False):
    """
    TypedDict for LLM usage metadata.

    All fields are optional (``total=False``). Price-like fields accept either
    a float or a string; ``LLMUsage.from_metadata`` converts them to Decimal
    via ``Decimal(str(...))``, so string values preserve exact precision.
    """

    # Token counts as reported by the model provider.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    # Per-token unit prices; float or decimal string.
    prompt_unit_price: Union[float, str]
    completion_unit_price: Union[float, str]
    # Total price of the call; float or decimal string.
    total_price: Union[float, str]
    # Currency code for the prices, e.g. "USD".
    currency: str
    # Pricing unit the unit prices are quoted in (e.g. per 1K tokens) —
    # presumably; confirm against the provider pricing schema.
    prompt_price_unit: Union[float, str]
    completion_price_unit: Union[float, str]
    # Computed price of the prompt / completion portions.
    prompt_price: Union[float, str]
    completion_price: Union[float, str]
    # Latency of the LLM call (assumed seconds — TODO confirm at call site).
    latency: float
class LLMUsage(ModelUsage):
"""
Model class for llm usage.
@@ -56,23 +76,27 @@ class LLMUsage(ModelUsage):
)
@classmethod
def from_metadata(cls, metadata: dict) -> LLMUsage:
def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage:
"""
Create LLMUsage instance from metadata dictionary with default values.
Args:
metadata: Dictionary containing usage metadata
metadata: TypedDict containing usage metadata
Returns:
LLMUsage instance with values from metadata or defaults
"""
total_tokens = metadata.get("total_tokens", 0)
prompt_tokens = metadata.get("prompt_tokens", 0)
completion_tokens = metadata.get("completion_tokens", 0)
if total_tokens > 0 and completion_tokens == 0:
completion_tokens = total_tokens
total_tokens = metadata.get("total_tokens", 0)
# If total_tokens is not provided but prompt and completion tokens are,
# calculate total_tokens
if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0):
total_tokens = prompt_tokens + completion_tokens
return cls(
prompt_tokens=metadata.get("prompt_tokens", 0),
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))),

View File

@@ -13,7 +13,7 @@ from core.agent.strategy.plugin import PluginAgentStrategy
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.request import InvokeCredentials
@@ -559,7 +559,7 @@ class AgentNode(BaseNode):
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT:
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(msg_metadata)
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()

View File

@@ -0,0 +1,148 @@
"""Tests for LLMUsage entity."""
from decimal import Decimal
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
class TestLLMUsage:
    """Test cases for LLMUsage class, focused on the from_metadata constructor."""

    def test_from_metadata_with_all_tokens(self):
        """Test from_metadata when all token types are provided."""
        metadata: LLMUsageMetadata = {
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "total_tokens": 150,
            "prompt_unit_price": 0.001,
            "completion_unit_price": 0.002,
            "total_price": 0.2,
            "currency": "USD",
            "latency": 1.5,
        }
        usage = LLMUsage.from_metadata(metadata)
        assert usage.prompt_tokens == 100
        assert usage.completion_tokens == 50
        assert usage.total_tokens == 150
        # Prices are converted via Decimal(str(...)), so float inputs
        # round-trip to their shortest-repr decimal form.
        assert usage.prompt_unit_price == Decimal("0.001")
        assert usage.completion_unit_price == Decimal("0.002")
        assert usage.total_price == Decimal("0.2")
        assert usage.currency == "USD"
        assert usage.latency == 1.5

    def test_from_metadata_with_prompt_tokens_only(self):
        """Test from_metadata when only prompt_tokens is provided."""
        metadata: LLMUsageMetadata = {
            "prompt_tokens": 100,
            "total_tokens": 100,
        }
        usage = LLMUsage.from_metadata(metadata)
        assert usage.prompt_tokens == 100
        # Absent completion_tokens must default to 0, not inherit the total.
        assert usage.completion_tokens == 0
        assert usage.total_tokens == 100

    def test_from_metadata_with_completion_tokens_only(self):
        """Test from_metadata when only completion_tokens is provided."""
        metadata: LLMUsageMetadata = {
            "completion_tokens": 50,
            "total_tokens": 50,
        }
        usage = LLMUsage.from_metadata(metadata)
        assert usage.prompt_tokens == 0
        assert usage.completion_tokens == 50
        assert usage.total_tokens == 50

    def test_from_metadata_calculates_total_when_missing(self):
        """Test from_metadata calculates total_tokens when not provided."""
        metadata: LLMUsageMetadata = {
            "prompt_tokens": 100,
            "completion_tokens": 50,
        }
        usage = LLMUsage.from_metadata(metadata)
        assert usage.prompt_tokens == 100
        assert usage.completion_tokens == 50
        assert usage.total_tokens == 150  # Should be calculated

    def test_from_metadata_with_total_but_no_completion(self):
        """
        Test from_metadata when total_tokens is provided but completion_tokens is 0.

        This tests the fix for issue #24360 - prompt tokens should NOT be assigned to completion_tokens.
        """
        metadata: LLMUsageMetadata = {
            "prompt_tokens": 479,
            "completion_tokens": 0,
            "total_tokens": 521,
        }
        usage = LLMUsage.from_metadata(metadata)
        # This is the key fix - prompt tokens should remain as prompt tokens
        assert usage.prompt_tokens == 479
        assert usage.completion_tokens == 0
        assert usage.total_tokens == 521

    def test_from_metadata_with_empty_metadata(self):
        """Test from_metadata with empty metadata."""
        metadata: LLMUsageMetadata = {}
        usage = LLMUsage.from_metadata(metadata)
        assert usage.prompt_tokens == 0
        assert usage.completion_tokens == 0
        assert usage.total_tokens == 0
        assert usage.currency == "USD"
        assert usage.latency == 0.0

    def test_from_metadata_preserves_zero_completion_tokens(self):
        """
        Test that zero completion_tokens are preserved when explicitly set.

        This is important for agent nodes that only use prompt tokens.
        """
        metadata: LLMUsageMetadata = {
            "prompt_tokens": 1000,
            "completion_tokens": 0,
            "total_tokens": 1000,
            "prompt_unit_price": 0.15,
            "completion_unit_price": 0.60,
            "prompt_price": 0.00015,
            "completion_price": 0,
            "total_price": 0.00015,
        }
        usage = LLMUsage.from_metadata(metadata)
        assert usage.prompt_tokens == 1000
        assert usage.completion_tokens == 0
        assert usage.total_tokens == 1000
        assert usage.prompt_price == Decimal("0.00015")
        assert usage.completion_price == Decimal(0)
        assert usage.total_price == Decimal("0.00015")

    def test_from_metadata_with_decimal_values(self):
        """Test from_metadata handles decimal values correctly."""
        metadata: LLMUsageMetadata = {
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "total_tokens": 150,
            # String prices exercise the exact-precision path of Decimal(str(...)).
            "prompt_unit_price": "0.001",
            "completion_unit_price": "0.002",
            "prompt_price": "0.1",
            "completion_price": "0.1",
            "total_price": "0.2",
        }
        usage = LLMUsage.from_metadata(metadata)
        assert usage.prompt_unit_price == Decimal("0.001")
        assert usage.completion_unit_price == Decimal("0.002")
        assert usage.prompt_price == Decimal("0.1")
        assert usage.completion_price == Decimal("0.1")
        assert usage.total_price == Decimal("0.2")