feat: upgrade langchain (#430)

Author: John Wang
Co-authored-by: jyong <718720800@qq.com>
Committed by: GitHub
Date: 2023-06-25 16:49:14 +08:00
Commit: 3241e4015b
Parent: 1dee5de9b4
91 changed files with 2703 additions and 3153 deletions


@@ -1,17 +1,18 @@
 import logging
 from typing import Optional, List, Union, Tuple

-from langchain.callbacks import CallbackManager
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.base import BaseCallbackHandler
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms import BaseLLM
-from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage
+from langchain.schema import BaseMessage, HumanMessage
 from requests.exceptions import ChunkedEncodingError

 from core.constant import llm_constant
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
     DifyStdOutCallbackHandler
-from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler
+from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
 from core.llm.error import LLMBadRequestError
 from core.llm.llm_builder import LLMBuilder
 from core.chain.main_chain_builder import MainChainBuilder
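
Note: the import moves above track langchain's reorganization: `BaseLanguageModel` now lives in `langchain.base_language` rather than `langchain.schema`, and the `CallbackManager` entry point gives way to plain lists of `BaseCallbackHandler` instances. A minimal compatibility sketch (the try/except shim is illustrative, not part of this commit):

```python
# Illustrative shim for the import move; the fallback is an assumption,
# not code from this commit.
try:
    # newer langchain releases export BaseLanguageModel here
    from langchain.base_language import BaseLanguageModel
except ImportError:
    # older releases exported it from langchain.schema
    from langchain.schema import BaseLanguageModel

# callbacks are now attached as a plain list of handler instances
from langchain.callbacks.base import BaseCallbackHandler
```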
@@ -34,8 +35,6 @@ class Completion:
         """
         errors: ProviderTokenNotInitError
         """
-        cls.validate_query_tokens(app.tenant_id, app_model_config, query)
-
         memory = None
         if conversation:
             # get memory of conversation (read-only)
@@ -48,6 +47,14 @@ class Completion:

             inputs = conversation.inputs

+        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
+            mode=app.mode,
+            tenant_id=app.tenant_id,
+            app_model_config=app_model_config,
+            query=query,
+            inputs=inputs
+        )
+
         conversation_message_task = ConversationMessageTask(
             task_id=task_id,
             app=app,
@@ -64,6 +71,7 @@ class Completion:
         main_chain = MainChainBuilder.to_langchain_components(
             tenant_id=app.tenant_id,
             agent_mode=app_model_config.agent_mode_dict,
+            rest_tokens=rest_tokens_for_context_and_memory,
             memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
             conversation_message_task=conversation_message_task
         )
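
Note: `rest_tokens` hands the chain builder a hard ceiling on how many tokens retrieved context and conversation memory may consume. A hypothetical consumer of such a budget (names invented for illustration; Dify's actual MainChainBuilder differs):

```python
from typing import Callable, List

def fit_documents(docs: List[str], rest_tokens: int,
                  count_tokens: Callable[[str], int]) -> List[str]:
    """Keep whole documents until the remaining token budget is spent.

    Hypothetical helper, not from the Dify codebase.
    """
    kept, budget = [], rest_tokens
    for doc in docs:
        cost = count_tokens(doc)
        if cost > budget:
            break  # the next document would overflow the budget
        kept.append(doc)
        budget -= cost
    return kept

# usage with a crude whitespace "tokenizer" as the counter
docs = ["alpha beta", "gamma delta epsilon", "zeta"]
print(fit_documents(docs, rest_tokens=4, count_tokens=lambda d: len(d.split())))
# -> ['alpha beta']  (2 tokens used; the next doc needs 3 but only 2 remain)
```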
@@ -115,7 +123,7 @@ class Completion:
             memory=memory
         )

-        final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task)
+        final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)

         cls.recale_llm_max_tokens(
             final_llm=final_llm,
@@ -247,16 +255,14 @@ And answer according to the language of the user's question.
         return messages, ['\nHuman:']

     @classmethod
-    def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
-                                 streaming: bool,
-                                 conversation_message_task: ConversationMessageTask) -> CallbackManager:
+    def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+                          streaming: bool,
+                          conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
         if streaming:
-            callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
+            return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
         else:
-            callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()]
-
-        return CallbackManager(callback_handlers)
+            return [llm_callback_handler, DifyStdOutCallbackHandler()]

     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
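
Note: this is the core API change of the upgrade: handlers are no longer wrapped in a `CallbackManager`; a plain list is assigned to the model's `callbacks`. A runnable before/after sketch using langchain's bundled test double (the handler name is invented):

```python
from langchain.callbacks.base import BaseCallbackHandler
from langchain.llms.fake import FakeListLLM  # test double shipped with langchain

class PrintingHandler(BaseCallbackHandler):
    """Toy stand-in for LLMCallbackHandler / DifyStdOutCallbackHandler."""
    def on_llm_start(self, serialized, prompts, **kwargs):
        print(f"LLM starting with {len(prompts)} prompt(s)")

llm = FakeListLLM(responses=["hello"])

# before the upgrade, handlers were wrapped:
#   llm.callback_manager = CallbackManager([PrintingHandler()])
# after the upgrade, a plain handler list is assigned:
llm.callbacks = [PrintingHandler()]
llm("hi")  # triggers on_llm_start on each handler
```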
@@ -293,7 +299,8 @@ And answer according to the language of the user's question.
         return memory

     @classmethod
-    def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
+    def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
+                                 query: str, inputs: dict) -> int:
         llm = LLMBuilder.to_llm_from_model(
             tenant_id=tenant_id,
             model=app_model_config.model_dict
@@ -302,8 +309,26 @@ And answer according to the language of the user's question.
         model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
         max_tokens = llm.max_tokens

-        if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
-            raise LLMBadRequestError("Query is too long")
+        # get prompt without memory and context
+        prompt, _ = cls.get_main_llm_prompt(
+            mode=mode,
+            llm=llm,
+            pre_prompt=app_model_config.pre_prompt,
+            query=query,
+            inputs=inputs,
+            chain_output=None,
+            memory=None
+        )
+
+        prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
+            else llm.get_num_tokens_from_messages(prompt)
+
+        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
+        if rest_tokens < 0:
+            raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
+                                     "or shrink the max token, or switch to a llm with a larger token limit size.")
+
+        return rest_tokens

     @classmethod
     def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
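
Note: the replaced check only counted the raw query; `get_validate_rest_tokens` now renders the full prompt (pre-prompt + inputs + query) before budgeting. The arithmetic is a straight split of the context window; a worked sketch with assumed numbers:

```python
# Worked example of the rest-token budget; all numbers are assumptions.
model_limited_tokens = 4096   # context window of a hypothetical 4k model
max_tokens = 512              # tokens reserved for the completion
prompt_tokens = 1200          # rendered prompt, without memory/context

rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
assert rest_tokens == 2384    # left over for dataset context + conversation memory

if rest_tokens < 0:
    raise ValueError("prompt + completion exceed the model's context window")
```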
@@ -360,7 +385,7 @@ And answer according to the language of the user's question.
             streaming=streaming
         )

-        llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task)
+        llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)

         cls.recale_llm_max_tokens(
             final_llm=llm,