feat: record price unit in messages (#919)

This commit is contained in:
takatost
2023-08-19 18:51:40 +08:00
committed by GitHub
parent 920fb6d0e1
commit 0a0d63457d
5 changed files with 88 additions and 4 deletions

View File

@@ -197,7 +197,7 @@ class BaseLLM(BaseProviderModel):
"""
raise NotImplementedError
def calc_tokens_price(self, tokens:int, message_type: MessageType):
def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal:
"""
calc tokens total price.
@@ -209,14 +209,14 @@ class BaseLLM(BaseProviderModel):
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
unit = self.price_config['unit']
unit = self.get_price_unit(message_type)
total_price = tokens * unit_price * unit
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
return total_price
def get_tokens_unit_price(self, message_type: MessageType):
def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal:
"""
get token price.
@@ -231,7 +231,23 @@ class BaseLLM(BaseProviderModel):
logging.debug(f"unit_price={unit_price}")
return unit_price
def get_currency(self):
def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
"""
get price unit.
:param message_type:
:return: decimal.Decimal('0.000001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
price_unit = self.price_config['unit']
else:
price_unit = self.price_config['unit']
price_unit = price_unit.quantize(decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
logging.debug(f"price_unit={price_unit}")
return price_unit
def get_currency(self) -> str:
"""
get token currency.