feat: optimize minimax llm call (#1312)

This commit is contained in:
takatost
2023-10-11 20:17:41 +08:00
committed by GitHub
parent c536f85b2e
commit 2851a9f04e
3 changed files with 287 additions and 12 deletions

View File

@@ -1,26 +1,23 @@
import decimal
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.llms import Minimax
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
class MinimaxModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return Minimax(
return MinimaxChatLLM(
model=self.name,
model_kwargs={
'stream': False
},
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
@@ -49,7 +46,7 @@ class MinimaxModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
return max(self._client.get_num_tokens_from_messages(prompts), 0)
def get_currency(self):
return 'RMB'
@@ -65,3 +62,7 @@ class MinimaxModel(BaseLLM):
return LLMBadRequestError(f"Minimax: {str(ex)}")
else:
return ex
@property
def support_streaming(self):
return True