Fix/human in answer (#174)
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, List, Union
|
from typing import Optional, List, Union, Tuple
|
||||||
|
|
||||||
from langchain.callbacks import CallbackManager
|
from langchain.callbacks import CallbackManager
|
||||||
from langchain.chat_models.base import BaseChatModel
|
from langchain.chat_models.base import BaseChatModel
|
||||||
@@ -97,7 +97,7 @@ class Completion:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# get llm prompt
|
# get llm prompt
|
||||||
prompt = cls.get_main_llm_prompt(
|
prompt, stop_words = cls.get_main_llm_prompt(
|
||||||
mode=mode,
|
mode=mode,
|
||||||
llm=final_llm,
|
llm=final_llm,
|
||||||
pre_prompt=app_model_config.pre_prompt,
|
pre_prompt=app_model_config.pre_prompt,
|
||||||
@@ -115,7 +115,7 @@ class Completion:
|
|||||||
mode=mode
|
mode=mode
|
||||||
)
|
)
|
||||||
|
|
||||||
response = final_llm.generate([prompt])
|
response = final_llm.generate([prompt], stop_words)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@@ -123,7 +123,7 @@ class Completion:
|
|||||||
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
|
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
|
||||||
chain_output: Optional[str],
|
chain_output: Optional[str],
|
||||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
|
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
|
||||||
Union[str | List[BaseMessage]]:
|
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
|
||||||
# disable template string in query
|
# disable template string in query
|
||||||
query_params = OutLinePromptTemplate.from_template(template=query).input_variables
|
query_params = OutLinePromptTemplate.from_template(template=query).input_variables
|
||||||
if query_params:
|
if query_params:
|
||||||
@@ -165,9 +165,9 @@ And answer according to the language of the user's question.
|
|||||||
|
|
||||||
if isinstance(llm, BaseChatModel):
|
if isinstance(llm, BaseChatModel):
|
||||||
# use chat llm as completion model
|
# use chat llm as completion model
|
||||||
return [HumanMessage(content=prompt_content)]
|
return [HumanMessage(content=prompt_content)], None
|
||||||
else:
|
else:
|
||||||
return prompt_content
|
return prompt_content, None
|
||||||
else:
|
else:
|
||||||
messages: List[BaseMessage] = []
|
messages: List[BaseMessage] = []
|
||||||
|
|
||||||
@@ -236,7 +236,7 @@ And answer according to the language of the user's question.
|
|||||||
|
|
||||||
messages.append(human_message)
|
messages.append(human_message)
|
||||||
|
|
||||||
return messages
|
return messages, ['\nHuman:']
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
|
def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
|
||||||
@@ -323,7 +323,7 @@ And answer according to the language of the user's question.
|
|||||||
)
|
)
|
||||||
|
|
||||||
# get llm prompt
|
# get llm prompt
|
||||||
original_prompt = cls.get_main_llm_prompt(
|
original_prompt, _ = cls.get_main_llm_prompt(
|
||||||
mode="completion",
|
mode="completion",
|
||||||
llm=llm,
|
llm=llm,
|
||||||
pre_prompt=pre_prompt,
|
pre_prompt=pre_prompt,
|
||||||
|
Reference in New Issue
Block a user