feat: remove llm client use (#1316)

This commit is contained in:
takatost
2023-10-12 03:02:53 +08:00
committed by GitHub
parent c007dbdc13
commit cbf095465c
14 changed files with 434 additions and 353 deletions

View File

@@ -1,6 +1,6 @@
import enum
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
from pydantic import BaseModel
@@ -9,6 +9,7 @@ class LLMRunResult(BaseModel):
prompt_tokens: int
completion_tokens: int
source: list = None
function_call: dict = None
class MessageType(enum.Enum):
@@ -20,6 +21,7 @@ class MessageType(enum.Enum):
class PromptMessage(BaseModel):
type: MessageType = MessageType.HUMAN
content: str = ''
function_call: dict = None
def to_lc_messages(messages: list[PromptMessage]):
@@ -28,7 +30,10 @@ def to_lc_messages(messages: list[PromptMessage]):
if message.type == MessageType.HUMAN:
lc_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
lc_messages.append(AIMessage(content=message.content))
additional_kwargs = {}
if message.function_call:
additional_kwargs['function_call'] = message.function_call
lc_messages.append(AIMessage(content=message.content, additional_kwargs=additional_kwargs))
elif message.type == MessageType.SYSTEM:
lc_messages.append(SystemMessage(content=message.content))
@@ -41,9 +46,19 @@ def to_prompt_messages(messages: list[BaseMessage]):
if isinstance(message, HumanMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
elif isinstance(message, AIMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
message_kwargs = {
'content': message.content,
'type': MessageType.ASSISTANT
}
if 'function_call' in message.additional_kwargs:
message_kwargs['function_call'] = message.additional_kwargs['function_call']
prompt_messages.append(PromptMessage(**message_kwargs))
elif isinstance(message, SystemMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
elif isinstance(message, FunctionMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
return prompt_messages