feat: remove llm client use (#1316)
```diff
@@ -81,7 +81,20 @@ class AzureOpenAIModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+        generate_kwargs = {
+            'stop': stop,
+            'callbacks': callbacks
+        }
+
+        if isinstance(prompts, str):
+            generate_kwargs['prompts'] = [prompts]
+        else:
+            generate_kwargs['messages'] = [prompts]
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)
 
     @property
     def base_model_name(self) -> str:
```
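Instead of always calling the LangChain client positionally with a prompt batch, the new code builds `generate_kwargs` and routes a plain string to `prompts=` (completion models) and a message list to `messages=` (chat models), forwarding `functions` only when the caller supplied them. The sketch below reproduces that dispatch against a hypothetical stand-in; `StubClient` and its `generate()` signature are illustrative, not the real LangChain client.

```python
# Minimal, self-contained sketch of the keyword-dispatch pattern this hunk
# introduces. StubClient and its generate() signature are hypothetical
# stand-ins for the LangChain client, not part of the Dify codebase.
from typing import List, Optional, Union


class StubClient:
    def generate(self,
                 prompts: Optional[List[str]] = None,
                 messages: Optional[list] = None,
                 stop: Optional[List[str]] = None,
                 callbacks: Optional[list] = None,
                 functions: Optional[List[dict]] = None) -> str:
        # A real client would invoke the model; this one just reports
        # which route the kwargs selected.
        if prompts is not None:
            return f"completion route: {len(prompts)} prompt(s)"
        return f"chat route: {len(messages)} message list(s)"


def generate(client: StubClient,
             prompts: Union[str, list],
             stop: Optional[List[str]] = None,
             callbacks: Optional[list] = None,
             **kwargs) -> str:
    generate_kwargs = {'stop': stop, 'callbacks': callbacks}

    if isinstance(prompts, str):
        # Completion-style models expect a batch of prompt strings.
        generate_kwargs['prompts'] = [prompts]
    else:
        # Chat-style models expect a batch of message lists.
        generate_kwargs['messages'] = [prompts]

    if 'functions' in kwargs:
        # Forward function-calling definitions only when supplied.
        generate_kwargs['functions'] = kwargs['functions']

    return client.generate(**generate_kwargs)


print(generate(StubClient(), "Hello"))              # completion route
print(generate(StubClient(), [{"role": "user"}]))   # chat route
```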
```diff
@@ -13,7 +13,8 @@ from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage,
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
 from core.helper import moderation
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
+from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \
+    to_lc_messages
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 from core.model_providers.providers.base import BaseModelProvider
 from core.prompt.prompt_builder import PromptBuilder
```
```diff
@@ -157,8 +158,11 @@ class BaseLLM(BaseProviderModel):
         except Exception as ex:
             raise self.handle_exceptions(ex)
 
+        function_call = None
         if isinstance(result.generations[0][0], ChatGeneration):
             completion_content = result.generations[0][0].message.content
+            if 'function_call' in result.generations[0][0].message.additional_kwargs:
+                function_call = result.generations[0][0].message.additional_kwargs.get('function_call')
         else:
             completion_content = result.generations[0][0].text
 
```
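When the chat model decides to call a function, LangChain surfaces the payload under the message's `additional_kwargs`; the hunk captures it into `function_call`, which stays `None` for ordinary text generations. A minimal sketch with plain dicts standing in for `ChatGeneration`/`AIMessage` (the payload shape mirrors the OpenAI function-calling response; the values are illustrative):

```python
# Sketch of the extraction logic with plain dicts standing in for
# LangChain's ChatGeneration/AIMessage; the payload shape mirrors the
# OpenAI function-calling response, and the values are illustrative.
message = {
    "content": "",
    "additional_kwargs": {
        "function_call": {
            "name": "get_weather",
            "arguments": '{"city": "Berlin"}',
        }
    },
}

function_call = None
completion_content = message["content"]
if "function_call" in message["additional_kwargs"]:
    function_call = message["additional_kwargs"].get("function_call")

print(function_call["name"])  # get_weather
```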
```diff
@@ -191,7 +195,8 @@ class BaseLLM(BaseProviderModel):
         return LLMRunResult(
             content=completion_content,
             prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens
+            completion_tokens=completion_tokens,
+            function_call=function_call
         )
 
     @abstractmethod
```
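For this constructor call to be valid, `LLMRunResult` presumably gained an optional `function_call` field defaulting to `None`. A minimal pydantic sketch of the assumed shape; the real definition lives in `core.model_providers.models.entity.message` and may differ:

```python
# Assumed shape of the result entity after this change. The real class
# lives in core.model_providers.models.entity.message; this pydantic
# sketch is an inference from the constructor call, not the definition.
from typing import Optional

from pydantic import BaseModel


class LLMRunResult(BaseModel):
    content: str
    prompt_tokens: int
    completion_tokens: int
    # New optional field; stays None for plain text completions.
    function_call: Optional[dict] = None
```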
```diff
@@ -442,16 +447,7 @@ class BaseLLM(BaseProviderModel):
         if len(messages) == 0:
             return []
 
-        chat_messages = []
-        for message in messages:
-            if message.type == MessageType.HUMAN:
-                chat_messages.append(HumanMessage(content=message.content))
-            elif message.type == MessageType.ASSISTANT:
-                chat_messages.append(AIMessage(content=message.content))
-            elif message.type == MessageType.SYSTEM:
-                chat_messages.append(SystemMessage(content=message.content))
-
-        return chat_messages
+        return to_lc_messages(messages)
 
     def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
         """
```
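The ten-line mapping loop collapses into the shared `to_lc_messages` helper. Its body is not shown in this diff, but the deleted loop pins down the expected behavior; a plausible reconstruction, assuming `PromptMessage` and `MessageType` from the same entity module:

```python
# Plausible reconstruction of to_lc_messages, inferred from the loop this
# hunk deletes. In the real codebase it sits alongside PromptMessage and
# MessageType in core.model_providers.models.entity.message, so no import
# of those names would be needed there; they are imported here only to
# make the sketch standalone.
from typing import List

from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage

from core.model_providers.models.entity.message import MessageType, PromptMessage


def to_lc_messages(messages: List[PromptMessage]) -> List[BaseMessage]:
    lc_messages = []
    for message in messages:
        if message.type == MessageType.HUMAN:
            lc_messages.append(HumanMessage(content=message.content))
        elif message.type == MessageType.ASSISTANT:
            lc_messages.append(AIMessage(content=message.content))
        elif message.type == MessageType.SYSTEM:
            lc_messages.append(SystemMessage(content=message.content))
    return lc_messages
```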
```diff
@@ -106,7 +106,21 @@ class OpenAIModel(BaseLLM):
             raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
 
         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+
+        generate_kwargs = {
+            'stop': stop,
+            'callbacks': callbacks
+        }
+
+        if isinstance(prompts, str):
+            generate_kwargs['prompts'] = [prompts]
+        else:
+            generate_kwargs['messages'] = [prompts]
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)
 
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
```
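`OpenAIModel` gets the same keyword-dispatch treatment as `AzureOpenAIModel` above. As a usage sketch, reusing the `StubClient` and `generate` helpers from the first example; `weather_fn` follows the OpenAI function-calling schema, and none of these names come from the Dify codebase:

```python
# Usage sketch reusing the StubClient/generate helpers defined after the
# Azure hunk above. weather_fn follows the OpenAI function-calling schema;
# the names are illustrative only.
weather_fn = {
    "name": "get_weather",
    "description": "Get the current weather for a city",
    "parameters": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}

client = StubClient()

# Completion-style prompt, no functions: only prompts/stop/callbacks are sent.
print(generate(client, "What is the weather in Berlin?"))

# Chat-style prompt with functions: the definitions are forwarded verbatim.
print(generate(client, [{"role": "user", "content": "Weather in Berlin?"}],
               functions=[weather_fn]))
```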