feat: add baichuan prompt (#985)

This commit is contained in:
takatost
2023-08-24 10:22:36 +08:00
committed by GitHub
parent 9b247fccd4
commit 2c30d19cbe
9 changed files with 213 additions and 130 deletions

View File

@@ -130,13 +130,12 @@ class Completion:
fake_response = agent_execute_result.output
# get llm prompt
prompt_messages, stop_words = cls.get_main_llm_prompt(
prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
agent_execute_result=agent_execute_result,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)
@@ -154,113 +153,6 @@ class Completion:
return response
@classmethod
def get_main_llm_prompt(cls, mode: str, model: dict,
pre_prompt: str, query: str, inputs: dict,
agent_execute_result: Optional[AgentExecuteResult],
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
Tuple[List[PromptMessage], Optional[List[str]]]:
if mode == 'completion':
prompt_template = JinjaPromptTemplate.from_template(
template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
{{context}}
</context>
When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
""" if agent_execute_result else "")
+ (pre_prompt + "\n" if pre_prompt else "")
+ "{{query}}\n"
)
if agent_execute_result:
inputs['context'] = agent_execute_result.output
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
prompt_content = prompt_template.format(
query=query,
**prompt_inputs
)
return [PromptMessage(content=prompt_content)], None
else:
messages: List[BaseMessage] = []
human_inputs = {
"query": query
}
human_message_prompt = ""
if pre_prompt:
pre_prompt_inputs = {k: inputs[k] for k in
JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
if k in inputs}
if pre_prompt_inputs:
human_inputs.update(pre_prompt_inputs)
if agent_execute_result:
human_inputs['context'] = agent_execute_result.output
human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
{{context}}
</context>
When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
"""
if pre_prompt:
human_message_prompt += pre_prompt
query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
if memory:
# append chat histories
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=human_message_prompt + query_prompt,
inputs=human_inputs
)
if memory.model_instance.model_rules.max_tokens.max:
curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message]))
max_tokens = model.get("completion_params").get('max_tokens')
rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
else:
rest_tokens = 2000
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
human_message_prompt += "\n\n" if human_message_prompt else ""
human_message_prompt += "Here is the chat histories between human and assistant, " \
"inside <histories></histories> XML tags.\n\n<histories>\n"
human_message_prompt += histories + "\n</histories>"
human_message_prompt += query_prompt
# construct main prompt
human_message = PromptBuilder.to_human_message(
prompt_content=human_message_prompt,
inputs=human_inputs
)
messages.append(human_message)
for message in messages:
message.content = re.sub(r'<\|.*?\|>', '', message.content)
return to_prompt_messages(messages), ['\nHuman:', '</histories>']
@classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
max_token_limit: int) -> str:
@@ -307,13 +199,12 @@ And answer according to the language of the user's question.
max_tokens = 0
# get prompt without memory and context
prompt_messages, _ = cls.get_main_llm_prompt(
prompt_messages, _ = model_instance.get_prompt(
mode=mode,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
agent_execute_result=None,
query=query,
context=None,
memory=None
)
@@ -358,13 +249,12 @@ And answer according to the language of the user's question.
)
# get llm prompt
old_prompt_messages, _ = cls.get_main_llm_prompt(
mode="completion",
model=app_model_config.model_dict,
old_prompt_messages, _ = final_model_instance.get_prompt(
mode='completion',
pre_prompt=pre_prompt,
query=message.query,
inputs=message.inputs,
agent_execute_result=None,
query=message.query,
context=None,
memory=None
)