feat: upgrade langchain (#430)

Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
John Wang
2023-06-25 16:49:14 +08:00
committed by GitHub
parent 1dee5de9b4
commit 3241e4015b
91 changed files with 2703 additions and 3153 deletions

View File

@@ -1,11 +1,9 @@
from typing import Optional, List
from typing import Optional, List, cast
from langchain.callbacks import SharedCallbackManager, CallbackManager
from langchain.chains import SequentialChain
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.chain_builder import ChainBuilder
@@ -18,6 +16,7 @@ from models.dataset import Dataset
class MainChainBuilder:
@classmethod
def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
rest_tokens: int,
conversation_message_task: ConversationMessageTask):
first_input_key = "input"
final_output_key = "output"
@@ -30,6 +29,7 @@ class MainChainBuilder:
tool_chains, chains_output_key = cls.get_agent_chains(
tenant_id=tenant_id,
agent_mode=agent_mode,
rest_tokens=rest_tokens,
memory=memory,
conversation_message_task=conversation_message_task
)
@@ -42,9 +42,8 @@ class MainChainBuilder:
return None
for chain in chains:
# do not add handler into singleton callback manager
if not isinstance(chain.callback_manager, SharedCallbackManager):
chain.callback_manager.add_handler(chain_callback_handler)
chain = cast(Chain, chain)
chain.callbacks.append(chain_callback_handler)
# build main chain
overall_chain = SequentialChain(
@@ -57,7 +56,9 @@ class MainChainBuilder:
return overall_chain
@classmethod
def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
rest_tokens: int,
memory: Optional[BaseChatMemory],
conversation_message_task: ConversationMessageTask):
# agent mode
chains = []
@@ -93,7 +94,8 @@ class MainChainBuilder:
tenant_id=tenant_id,
datasets=datasets,
conversation_message_task=conversation_message_task,
callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
rest_tokens=rest_tokens,
callbacks=[DifyStdOutCallbackHandler()]
)
chains.append(multi_dataset_router_chain)