fix: correct agent node token counting to properly separate prompt and completion tokens (#24368)
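Before this change, LLMUsage.from_metadata copied total_tokens into completion_tokens whenever completion_tokens was reported as 0, so agent runs that consumed mostly prompt tokens showed an inflated completion count. The diff below keeps the two counts separate and only derives total_tokens when it is missing. A minimal sketch of the intended behavior after the fix, using values taken from the tests added in this commit:

from core.model_runtime.entities.llm_entities import LLMUsage

# Usage metadata reported by an agent node: prompt-heavy usage, no completions.
usage = LLMUsage.from_metadata({"prompt_tokens": 479, "completion_tokens": 0, "total_tokens": 521})

assert usage.prompt_tokens == 479      # no longer folded into completion_tokens
assert usage.completion_tokens == 0
assert usage.total_tokens == 521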
@@ -3,7 +3,7 @@ from __future__ import annotations
 from collections.abc import Mapping, Sequence
 from decimal import Decimal
 from enum import StrEnum
-from typing import Any, Optional
+from typing import Any, Optional, TypedDict, Union
 
 from pydantic import BaseModel, Field
 
@@ -20,6 +20,26 @@ class LLMMode(StrEnum):
     CHAT = "chat"
 
 
+class LLMUsageMetadata(TypedDict, total=False):
+    """
+    TypedDict for LLM usage metadata.
+    All fields are optional.
+    """
+
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+    prompt_unit_price: Union[float, str]
+    completion_unit_price: Union[float, str]
+    total_price: Union[float, str]
+    currency: str
+    prompt_price_unit: Union[float, str]
+    completion_price_unit: Union[float, str]
+    prompt_price: Union[float, str]
+    completion_price: Union[float, str]
+    latency: float
+
+
 class LLMUsage(ModelUsage):
     """
     Model class for llm usage.
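Because LLMUsageMetadata is declared with total=False, every key is optional for the type checker, so callers can pass whatever subset of usage fields the runtime actually reported and from_metadata fills in defaults. A small illustrative sketch (field values are made up):

from core.model_runtime.entities.llm_entities import LLMUsageMetadata

# Partial metadata is valid: only the fields that were actually reported are set.
metadata: LLMUsageMetadata = {
    "prompt_tokens": 100,
    "completion_tokens": 50,
    "currency": "USD",
}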
@@ -56,23 +76,27 @@ class LLMUsage(ModelUsage):
         )
 
     @classmethod
-    def from_metadata(cls, metadata: dict) -> LLMUsage:
+    def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage:
         """
         Create LLMUsage instance from metadata dictionary with default values.
 
         Args:
-            metadata: Dictionary containing usage metadata
+            metadata: TypedDict containing usage metadata
 
         Returns:
             LLMUsage instance with values from metadata or defaults
         """
-        total_tokens = metadata.get("total_tokens", 0)
+        prompt_tokens = metadata.get("prompt_tokens", 0)
         completion_tokens = metadata.get("completion_tokens", 0)
-        if total_tokens > 0 and completion_tokens == 0:
-            completion_tokens = total_tokens
+        total_tokens = metadata.get("total_tokens", 0)
+
+        # If total_tokens is not provided but prompt and completion tokens are,
+        # calculate total_tokens
+        if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0):
+            total_tokens = prompt_tokens + completion_tokens
 
         return cls(
-            prompt_tokens=metadata.get("prompt_tokens", 0),
+            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
             total_tokens=total_tokens,
             prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))),
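A short worked sketch of the rewritten method (illustrative numbers): total_tokens is derived only when it is absent, and price fields, which may arrive as float or str, are normalised through Decimal(str(...)) as in the diff above.

from decimal import Decimal

from core.model_runtime.entities.llm_entities import LLMUsage

# total_tokens omitted -> derived as prompt_tokens + completion_tokens
usage = LLMUsage.from_metadata({"prompt_tokens": 100, "completion_tokens": 50})
assert usage.total_tokens == 150

# price fields accept float or str and come back as Decimal
usage = LLMUsage.from_metadata({"prompt_unit_price": "0.001"})
assert usage.prompt_unit_price == Decimal("0.001")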
@@ -13,7 +13,7 @@ from core.agent.strategy.plugin import PluginAgentStrategy
 from core.file import File, FileTransferMethod
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.plugin.entities.request import InvokeCredentials
@@ -559,7 +559,7 @@ class AgentNode(BaseNode):
             assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
             if node_type == NodeType.AGENT:
                 msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
-                llm_usage = LLMUsage.from_metadata(msg_metadata)
+                llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
                 agent_execution_metadata = {
                     WorkflowNodeExecutionMetadataKey(key): value
                     for key, value in msg_metadata.items()
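The cast() here only narrows the static type: typing.cast is a no-op at runtime, so msg_metadata is still the plain dict popped from execution_metadata, and from_metadata reads it with .get() defaults, which means keys beyond the usage fields are simply ignored. A hedged sketch of the same pattern in isolation (values illustrative):

from typing import Any, cast

from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata

# Usage fields popped out of the agent message's execution_metadata.
msg_metadata: dict[str, Any] = {"prompt_tokens": 479, "completion_tokens": 0, "total_tokens": 521}

# cast() only tells the type checker to treat the dict as LLMUsageMetadata.
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))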
@@ -0,0 +1,148 @@
+"""Tests for LLMUsage entity."""
+
+from decimal import Decimal
+
+from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
+
+
+class TestLLMUsage:
+    """Test cases for LLMUsage class."""
+
+    def test_from_metadata_with_all_tokens(self):
+        """Test from_metadata when all token types are provided."""
+        metadata: LLMUsageMetadata = {
+            "prompt_tokens": 100,
+            "completion_tokens": 50,
+            "total_tokens": 150,
+            "prompt_unit_price": 0.001,
+            "completion_unit_price": 0.002,
+            "total_price": 0.2,
+            "currency": "USD",
+            "latency": 1.5,
+        }
+
+        usage = LLMUsage.from_metadata(metadata)
+
+        assert usage.prompt_tokens == 100
+        assert usage.completion_tokens == 50
+        assert usage.total_tokens == 150
+        assert usage.prompt_unit_price == Decimal("0.001")
+        assert usage.completion_unit_price == Decimal("0.002")
+        assert usage.total_price == Decimal("0.2")
+        assert usage.currency == "USD"
+        assert usage.latency == 1.5
+
+    def test_from_metadata_with_prompt_tokens_only(self):
+        """Test from_metadata when only prompt_tokens is provided."""
+        metadata: LLMUsageMetadata = {
+            "prompt_tokens": 100,
+            "total_tokens": 100,
+        }
+
+        usage = LLMUsage.from_metadata(metadata)
+
+        assert usage.prompt_tokens == 100
+        assert usage.completion_tokens == 0
+        assert usage.total_tokens == 100
+
+    def test_from_metadata_with_completion_tokens_only(self):
+        """Test from_metadata when only completion_tokens is provided."""
+        metadata: LLMUsageMetadata = {
+            "completion_tokens": 50,
+            "total_tokens": 50,
+        }
+
+        usage = LLMUsage.from_metadata(metadata)
+
+        assert usage.prompt_tokens == 0
+        assert usage.completion_tokens == 50
+        assert usage.total_tokens == 50
+
+    def test_from_metadata_calculates_total_when_missing(self):
+        """Test from_metadata calculates total_tokens when not provided."""
+        metadata: LLMUsageMetadata = {
+            "prompt_tokens": 100,
+            "completion_tokens": 50,
+        }
+
+        usage = LLMUsage.from_metadata(metadata)
+
+        assert usage.prompt_tokens == 100
+        assert usage.completion_tokens == 50
+        assert usage.total_tokens == 150  # Should be calculated
+
+    def test_from_metadata_with_total_but_no_completion(self):
+        """
+        Test from_metadata when total_tokens is provided but completion_tokens is 0.
+        This tests the fix for issue #24360 - prompt tokens should NOT be assigned to completion_tokens.
+        """
+        metadata: LLMUsageMetadata = {
+            "prompt_tokens": 479,
+            "completion_tokens": 0,
+            "total_tokens": 521,
+        }
+
+        usage = LLMUsage.from_metadata(metadata)
+
+        # This is the key fix - prompt tokens should remain as prompt tokens
+        assert usage.prompt_tokens == 479
+        assert usage.completion_tokens == 0
+        assert usage.total_tokens == 521
+
+    def test_from_metadata_with_empty_metadata(self):
+        """Test from_metadata with empty metadata."""
+        metadata: LLMUsageMetadata = {}
+
+        usage = LLMUsage.from_metadata(metadata)
+
+        assert usage.prompt_tokens == 0
+        assert usage.completion_tokens == 0
+        assert usage.total_tokens == 0
+        assert usage.currency == "USD"
+        assert usage.latency == 0.0
+
+    def test_from_metadata_preserves_zero_completion_tokens(self):
+        """
+        Test that zero completion_tokens are preserved when explicitly set.
+        This is important for agent nodes that only use prompt tokens.
+        """
+        metadata: LLMUsageMetadata = {
+            "prompt_tokens": 1000,
+            "completion_tokens": 0,
+            "total_tokens": 1000,
+            "prompt_unit_price": 0.15,
+            "completion_unit_price": 0.60,
+            "prompt_price": 0.00015,
+            "completion_price": 0,
+            "total_price": 0.00015,
+        }
+
+        usage = LLMUsage.from_metadata(metadata)
+
+        assert usage.prompt_tokens == 1000
+        assert usage.completion_tokens == 0
+        assert usage.total_tokens == 1000
+        assert usage.prompt_price == Decimal("0.00015")
+        assert usage.completion_price == Decimal(0)
+        assert usage.total_price == Decimal("0.00015")
+
+    def test_from_metadata_with_decimal_values(self):
+        """Test from_metadata handles decimal values correctly."""
+        metadata: LLMUsageMetadata = {
+            "prompt_tokens": 100,
+            "completion_tokens": 50,
+            "total_tokens": 150,
+            "prompt_unit_price": "0.001",
+            "completion_unit_price": "0.002",
+            "prompt_price": "0.1",
+            "completion_price": "0.1",
+            "total_price": "0.2",
+        }
+
+        usage = LLMUsage.from_metadata(metadata)
+
+        assert usage.prompt_unit_price == Decimal("0.001")
+        assert usage.completion_unit_price == Decimal("0.002")
+        assert usage.prompt_price == Decimal("0.1")
+        assert usage.completion_price == Decimal("0.1")
+        assert usage.total_price == Decimal("0.2")