feat: claude api support (#572)

This commit is contained in:
John Wang
2023-07-17 00:14:19 +08:00
committed by GitHub
parent 510389909c
commit 7599f79a17
52 changed files with 637 additions and 349 deletions

View File

@@ -1,7 +1,7 @@
from typing import Any, List, Dict, Union
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
@@ -12,8 +12,8 @@ from models.model import Conversation, Message
class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
conversation: Conversation
human_prefix: str = "Human"
ai_prefix: str = "AI"
llm: Union[StreamableChatOpenAI | StreamableOpenAI]
ai_prefix: str = "Assistant"
llm: BaseLanguageModel
memory_key: str = "chat_history"
max_token_limit: int = 2000
message_limit: int = 10
@@ -38,12 +38,12 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
return chat_messages
# prune the chat message if it exceeds the max token limit
curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
if curr_buffer_length > self.max_token_limit:
pruned_memory = []
while curr_buffer_length > self.max_token_limit and chat_messages:
pruned_memory.append(chat_messages.pop(0))
curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
return chat_messages