feat: record price unit in messages (#919)
This commit is contained in:
@@ -197,7 +197,7 @@ class BaseLLM(BaseProviderModel):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def calc_tokens_price(self, tokens:int, message_type: MessageType):
|
||||
def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal:
|
||||
"""
|
||||
calc tokens total price.
|
||||
|
||||
@@ -209,14 +209,14 @@ class BaseLLM(BaseProviderModel):
|
||||
unit_price = self.price_config['prompt']
|
||||
else:
|
||||
unit_price = self.price_config['completion']
|
||||
unit = self.price_config['unit']
|
||||
unit = self.get_price_unit(message_type)
|
||||
|
||||
total_price = tokens * unit_price * unit
|
||||
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
|
||||
return total_price
|
||||
|
||||
def get_tokens_unit_price(self, message_type: MessageType):
|
||||
def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal:
|
||||
"""
|
||||
get token price.
|
||||
|
||||
@@ -231,7 +231,23 @@ class BaseLLM(BaseProviderModel):
|
||||
logging.debug(f"unit_price={unit_price}")
|
||||
return unit_price
|
||||
|
||||
def get_currency(self):
|
||||
def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
|
||||
"""
|
||||
get price unit.
|
||||
|
||||
:param message_type:
|
||||
:return: decimal.Decimal('0.000001')
|
||||
"""
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
price_unit = self.price_config['unit']
|
||||
else:
|
||||
price_unit = self.price_config['unit']
|
||||
|
||||
price_unit = price_unit.quantize(decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logging.debug(f"price_unit={price_unit}")
|
||||
return price_unit
|
||||
|
||||
def get_currency(self) -> str:
|
||||
"""
|
||||
get token currency.
|
||||
|
||||
|
Reference in New Issue
Block a user