From 4350bb9a00799e20e5cd963687080bfd09405ab8 Mon Sep 17 00:00:00 2001
From: John Wang
Date: Tue, 23 May 2023 19:54:04 +0800
Subject: [PATCH] Fix/human in answer (#174)

---
 api/core/completion.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/api/core/completion.py b/api/core/completion.py
index 47658f3bf..5e559ac7c 100644
--- a/api/core/completion.py
+++ b/api/core/completion.py
@@ -1,4 +1,4 @@
-from typing import Optional, List, Union
+from typing import Optional, List, Union, Tuple
 
 from langchain.callbacks import CallbackManager
 from langchain.chat_models.base import BaseChatModel
@@ -97,7 +97,7 @@ class Completion:
         )
 
         # get llm prompt
-        prompt = cls.get_main_llm_prompt(
+        prompt, stop_words = cls.get_main_llm_prompt(
             mode=mode,
             llm=final_llm,
             pre_prompt=app_model_config.pre_prompt,
@@ -115,7 +115,7 @@
             mode=mode
         )
 
-        response = final_llm.generate([prompt])
+        response = final_llm.generate([prompt], stop_words)
 
         return response
 
@@ -123,7 +123,7 @@
     def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
                             chain_output: Optional[str],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
-            Union[str | List[BaseMessage]]:
+            Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
         # disable template string in query
         query_params = OutLinePromptTemplate.from_template(template=query).input_variables
         if query_params:
@@ -165,9 +165,9 @@ And answer according to the language of the user's question.
 
             if isinstance(llm, BaseChatModel):
                 # use chat llm as completion model
-                return [HumanMessage(content=prompt_content)]
+                return [HumanMessage(content=prompt_content)], None
             else:
-                return prompt_content
+                return prompt_content, None
         else:
             messages: List[BaseMessage] = []
 
@@ -236,7 +236,7 @@ And answer according to the language of the user's question.
 
             messages.append(human_message)
 
-            return messages
+            return messages, ['\nHuman:']
 
     @classmethod
     def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
@@ -323,7 +323,7 @@ And answer according to the language of the user's question.
         )
 
         # get llm prompt
-        original_prompt = cls.get_main_llm_prompt(
+        original_prompt, _ = cls.get_main_llm_prompt(
             mode="completion",
             llm=llm,
             pre_prompt=pre_prompt,
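
Review note: the patch threads an optional stop-word list from prompt construction
through to generation. get_main_llm_prompt now returns a (prompt, stop_words) tuple;
conversation mode returns ['\nHuman:'] as the stop list, while completion mode returns
None. Completion.generate then forwards the list via final_llm.generate([prompt],
stop_words), which matches langchain's generate(prompts, stop=...) signature, where
the second positional argument is the stop list. The point: a completion-style model
fed a "Human:/Assistant:" transcript tends to keep writing the user's next turn, so
the backend truncates output at the first '\nHuman:'.

Below is a minimal, self-contained sketch of that truncation behavior. It is
illustrative only, not Dify's or langchain's code; apply_stop_words is a hypothetical
helper standing in for what the LLM backend does with a stop list.

from typing import List, Optional


def apply_stop_words(completion: str, stop: Optional[List[str]]) -> str:
    """Cut the raw completion at the earliest stop sequence, as an
    LLM backend would when given a stop list."""
    if not stop:
        return completion
    cut = len(completion)
    for s in stop:
        idx = completion.find(s)
        if idx != -1:
            cut = min(cut, idx)
    return completion[:cut]


# A raw completion where the model starts inventing the user's next message:
raw = "Sure, here is the summary you asked for.\nHuman: thanks, now shorten it"

print(apply_stop_words(raw, None))
# keeps the fabricated "Human: ..." turn

print(apply_stop_words(raw, ['\nHuman:']))
# -> "Sure, here is the summary you asked for."

With stop=None (the completion-mode path above), nothing is cut; with ['\nHuman:']
(the conversation-mode path), the answer ends before the model speaks for the user,
which is the bug the subject line "Fix/human in answer" refers to.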