From 2e47558f4b9638bec50f1120fe259e9b6266a20b Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 23 Aug 2025 11:00:14 +0800 Subject: [PATCH] fix: correct agent node token counting to properly separate prompt and completion tokens (#24368) --- .../model_runtime/entities/llm_entities.py | 38 ++++- api/core/workflow/nodes/agent/agent_node.py | 4 +- .../entities/test_llm_entities.py | 148 ++++++++++++++++++ 3 files changed, 181 insertions(+), 9 deletions(-) create mode 100644 api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index 0e1277bc8..dc6032e40 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from decimal import Decimal from enum import StrEnum -from typing import Any, Optional +from typing import Any, Optional, TypedDict, Union from pydantic import BaseModel, Field @@ -20,6 +20,26 @@ class LLMMode(StrEnum): CHAT = "chat" +class LLMUsageMetadata(TypedDict, total=False): + """ + TypedDict for LLM usage metadata. + All fields are optional. + """ + + prompt_tokens: int + completion_tokens: int + total_tokens: int + prompt_unit_price: Union[float, str] + completion_unit_price: Union[float, str] + total_price: Union[float, str] + currency: str + prompt_price_unit: Union[float, str] + completion_price_unit: Union[float, str] + prompt_price: Union[float, str] + completion_price: Union[float, str] + latency: float + + class LLMUsage(ModelUsage): """ Model class for llm usage. @@ -56,23 +76,27 @@ class LLMUsage(ModelUsage): ) @classmethod - def from_metadata(cls, metadata: dict) -> LLMUsage: + def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage: """ Create LLMUsage instance from metadata dictionary with default values. Args: - metadata: Dictionary containing usage metadata + metadata: TypedDict containing usage metadata Returns: LLMUsage instance with values from metadata or defaults """ - total_tokens = metadata.get("total_tokens", 0) + prompt_tokens = metadata.get("prompt_tokens", 0) completion_tokens = metadata.get("completion_tokens", 0) - if total_tokens > 0 and completion_tokens == 0: - completion_tokens = total_tokens + total_tokens = metadata.get("total_tokens", 0) + + # If total_tokens is not provided but prompt and completion tokens are, + # calculate total_tokens + if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0): + total_tokens = prompt_tokens + completion_tokens return cls( - prompt_tokens=metadata.get("prompt_tokens", 0), + prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))), diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 419f2ca55..144f036aa 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -13,7 +13,7 @@ from core.agent.strategy.plugin import PluginAgentStrategy from core.file import File, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.request import InvokeCredentials @@ -559,7 +559,7 @@ class AgentNode(BaseNode): assert isinstance(message.message, ToolInvokeMessage.JsonMessage) if node_type == NodeType.AGENT: msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) - llm_usage = LLMUsage.from_metadata(msg_metadata) + llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata)) agent_execution_metadata = { WorkflowNodeExecutionMetadataKey(key): value for key, value in msg_metadata.items() diff --git a/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py b/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py new file mode 100644 index 000000000..c10f7b89c --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py @@ -0,0 +1,148 @@ +"""Tests for LLMUsage entity.""" + +from decimal import Decimal + +from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata + + +class TestLLMUsage: + """Test cases for LLMUsage class.""" + + def test_from_metadata_with_all_tokens(self): + """Test from_metadata when all token types are provided.""" + metadata: LLMUsageMetadata = { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "prompt_unit_price": 0.001, + "completion_unit_price": 0.002, + "total_price": 0.2, + "currency": "USD", + "latency": 1.5, + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 100 + assert usage.completion_tokens == 50 + assert usage.total_tokens == 150 + assert usage.prompt_unit_price == Decimal("0.001") + assert usage.completion_unit_price == Decimal("0.002") + assert usage.total_price == Decimal("0.2") + assert usage.currency == "USD" + assert usage.latency == 1.5 + + def test_from_metadata_with_prompt_tokens_only(self): + """Test from_metadata when only prompt_tokens is provided.""" + metadata: LLMUsageMetadata = { + "prompt_tokens": 100, + "total_tokens": 100, + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 100 + assert usage.completion_tokens == 0 + assert usage.total_tokens == 100 + + def test_from_metadata_with_completion_tokens_only(self): + """Test from_metadata when only completion_tokens is provided.""" + metadata: LLMUsageMetadata = { + "completion_tokens": 50, + "total_tokens": 50, + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 0 + assert usage.completion_tokens == 50 + assert usage.total_tokens == 50 + + def test_from_metadata_calculates_total_when_missing(self): + """Test from_metadata calculates total_tokens when not provided.""" + metadata: LLMUsageMetadata = { + "prompt_tokens": 100, + "completion_tokens": 50, + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 100 + assert usage.completion_tokens == 50 + assert usage.total_tokens == 150 # Should be calculated + + def test_from_metadata_with_total_but_no_completion(self): + """ + Test from_metadata when total_tokens is provided but completion_tokens is 0. + This tests the fix for issue #24360 - prompt tokens should NOT be assigned to completion_tokens. + """ + metadata: LLMUsageMetadata = { + "prompt_tokens": 479, + "completion_tokens": 0, + "total_tokens": 521, + } + + usage = LLMUsage.from_metadata(metadata) + + # This is the key fix - prompt tokens should remain as prompt tokens + assert usage.prompt_tokens == 479 + assert usage.completion_tokens == 0 + assert usage.total_tokens == 521 + + def test_from_metadata_with_empty_metadata(self): + """Test from_metadata with empty metadata.""" + metadata: LLMUsageMetadata = {} + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 0 + assert usage.completion_tokens == 0 + assert usage.total_tokens == 0 + assert usage.currency == "USD" + assert usage.latency == 0.0 + + def test_from_metadata_preserves_zero_completion_tokens(self): + """ + Test that zero completion_tokens are preserved when explicitly set. + This is important for agent nodes that only use prompt tokens. + """ + metadata: LLMUsageMetadata = { + "prompt_tokens": 1000, + "completion_tokens": 0, + "total_tokens": 1000, + "prompt_unit_price": 0.15, + "completion_unit_price": 0.60, + "prompt_price": 0.00015, + "completion_price": 0, + "total_price": 0.00015, + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 1000 + assert usage.completion_tokens == 0 + assert usage.total_tokens == 1000 + assert usage.prompt_price == Decimal("0.00015") + assert usage.completion_price == Decimal(0) + assert usage.total_price == Decimal("0.00015") + + def test_from_metadata_with_decimal_values(self): + """Test from_metadata handles decimal values correctly.""" + metadata: LLMUsageMetadata = { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "prompt_unit_price": "0.001", + "completion_unit_price": "0.002", + "prompt_price": "0.1", + "completion_price": "0.1", + "total_price": "0.2", + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_unit_price == Decimal("0.001") + assert usage.completion_unit_price == Decimal("0.002") + assert usage.prompt_price == Decimal("0.1") + assert usage.completion_price == Decimal("0.1") + assert usage.total_price == Decimal("0.2")