Fix/price calc (#862)
This commit is contained in:
@@ -31,6 +31,15 @@ class AzureOpenAIEmbedding(BaseEmbedding):
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
@property
|
||||
def base_model_name(self) -> str:
|
||||
"""
|
||||
get base model name (not deployment)
|
||||
|
||||
:return: str
|
||||
"""
|
||||
return self.credentials.get("base_model_name")
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""
|
||||
@@ -49,16 +58,6 @@ class AzureOpenAIEmbedding(BaseEmbedding):
|
||||
# calculate the number of tokens in the encoded text
|
||||
return len(tokenized_text)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * decimal.Decimal('0.0001')
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, openai.error.InvalidRequestError):
|
||||
logging.warning("Invalid request to Azure OpenAI API.")
|
||||
|
@@ -1,5 +1,6 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
import decimal
|
||||
|
||||
import tiktoken
|
||||
from langchain.schema.language_model import _get_token_ids_default_method
|
||||
@@ -7,7 +8,8 @@ from langchain.schema.language_model import _get_token_ids_default_method
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseEmbedding(BaseProviderModel):
|
||||
name: str
|
||||
@@ -17,6 +19,65 @@ class BaseEmbedding(BaseProviderModel):
|
||||
super().__init__(model_provider, client)
|
||||
self.name = name
|
||||
|
||||
@property
|
||||
def base_model_name(self) -> str:
|
||||
"""
|
||||
get base model name
|
||||
|
||||
:return: str
|
||||
"""
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def price_config(self) -> dict:
|
||||
def get_or_default():
|
||||
default_price_config = {
|
||||
'prompt': decimal.Decimal('0'),
|
||||
'completion': decimal.Decimal('0'),
|
||||
'unit': decimal.Decimal('0'),
|
||||
'currency': 'USD'
|
||||
}
|
||||
rules = self.model_provider.get_rules()
|
||||
price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
|
||||
price_config = {
|
||||
'prompt': decimal.Decimal(price_config['prompt']),
|
||||
'completion': decimal.Decimal(price_config['completion']),
|
||||
'unit': decimal.Decimal(price_config['unit']),
|
||||
'currency': price_config['currency']
|
||||
}
|
||||
return price_config
|
||||
|
||||
self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
|
||||
|
||||
logger.debug(f"model: {self.name} price_config: {self._price_config}")
|
||||
return self._price_config
|
||||
|
||||
def calc_tokens_price(self, tokens:int) -> decimal.Decimal:
|
||||
"""
|
||||
calc tokens total price.
|
||||
|
||||
:param tokens:
|
||||
:return: decimal.Decimal('0.0000001')
|
||||
"""
|
||||
unit_price = self._price_config['completion']
|
||||
unit = self._price_config['unit']
|
||||
total_price = tokens * unit_price * unit
|
||||
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
|
||||
return total_price
|
||||
|
||||
def get_tokens_unit_price(self) -> decimal.Decimal:
|
||||
"""
|
||||
get token price.
|
||||
|
||||
:return: decimal.Decimal('0.0001')
|
||||
|
||||
"""
|
||||
unit_price = self._price_config['completion']
|
||||
unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logger.debug(f'unit_price:{unit_price}')
|
||||
return unit_price
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""
|
||||
get num tokens of text.
|
||||
@@ -29,11 +90,14 @@ class BaseEmbedding(BaseProviderModel):
|
||||
|
||||
return len(_get_token_ids_default_method(text))
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
return 0
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
"""
|
||||
get token currency.
|
||||
|
||||
:return: get from price config, default 'USD'
|
||||
"""
|
||||
currency = self._price_config['currency']
|
||||
return currency
|
||||
|
||||
@abstractmethod
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
|
@@ -22,9 +22,6 @@ class MinimaxEmbedding(BaseEmbedding):
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
|
@@ -42,16 +42,6 @@ class OpenAIEmbedding(BaseEmbedding):
|
||||
# calculate the number of tokens in the encoded text
|
||||
return len(tokenized_text)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * decimal.Decimal('0.0001')
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, openai.error.InvalidRequestError):
|
||||
logging.warning("Invalid request to OpenAI API.")
|
||||
|
@@ -22,13 +22,6 @@ class ReplicateEmbedding(BaseEmbedding):
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
# replicate only pay for prediction seconds
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, (ModelError, ReplicateError)):
|
||||
return LLMBadRequestError(f"Replicate: {str(ex)}")
|
||||
|
Reference in New Issue
Block a user