feat: optimize template parse (#460)

This commit is contained in:
John Wang
2023-06-27 15:30:38 +08:00
committed by GitHub
parent df5763be37
commit c720f831af
7 changed files with 83 additions and 44 deletions

View File

@@ -23,7 +23,7 @@ from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
ReadOnlyConversationTokenDBStringBufferSharedMemory
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import OutLinePromptTemplate
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.model import App, AppModelConfig, Account, Conversation, Message
@@ -35,6 +35,8 @@ class Completion:
"""
errors: ProviderTokenNotInitError
"""
query = PromptBuilder.process_template(query)
memory = None
if conversation:
# get memory of conversation (read-only)
@@ -141,18 +143,17 @@ class Completion:
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
# disable template string in query
query_params = OutLinePromptTemplate.from_template(template=query).input_variables
if query_params:
for query_param in query_params:
if query_param not in inputs:
inputs[query_param] = '{' + query_param + '}'
# query_params = JinjaPromptTemplate.from_template(template=query).input_variables
# if query_params:
# for query_param in query_params:
# if query_param not in inputs:
# inputs[query_param] = '{{' + query_param + '}}'
pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
if mode == 'completion':
prompt_template = OutLinePromptTemplate.from_template(
prompt_template = JinjaPromptTemplate.from_template(
template=("""Use the following CONTEXT as your learned knowledge:
[CONTEXT]
{context}
{{context}}
[END CONTEXT]
When answer to user:
@@ -162,16 +163,16 @@ Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
""" if chain_output else "")
+ (pre_prompt + "\n" if pre_prompt else "")
+ "{query}\n"
+ "{{query}}\n"
)
if chain_output:
inputs['context'] = chain_output
context_params = OutLinePromptTemplate.from_template(template=chain_output).input_variables
if context_params:
for context_param in context_params:
if context_param not in inputs:
inputs[context_param] = '{' + context_param + '}'
# context_params = JinjaPromptTemplate.from_template(template=chain_output).input_variables
# if context_params:
# for context_param in context_params:
# if context_param not in inputs:
# inputs[context_param] = '{{' + context_param + '}}'
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
prompt_content = prompt_template.format(
@@ -195,7 +196,7 @@ And answer according to the language of the user's question.
if pre_prompt:
pre_prompt_inputs = {k: inputs[k] for k in
OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
if k in inputs}
if pre_prompt_inputs:
@@ -205,7 +206,7 @@ And answer according to the language of the user's question.
human_inputs['context'] = chain_output
human_message_prompt += """Use the following CONTEXT as your learned knowledge.
[CONTEXT]
{context}
{{context}}
[END CONTEXT]
When answer to user:
@@ -218,7 +219,7 @@ And answer according to the language of the user's question.
if pre_prompt:
human_message_prompt += pre_prompt
query_prompt = "\nHuman: {query}\nAI: "
query_prompt = "\nHuman: {{query}}\nAI: "
if memory:
# append chat histories
@@ -234,11 +235,11 @@ And answer according to the language of the user's question.
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
# disable template string in query
histories_params = OutLinePromptTemplate.from_template(template=histories).input_variables
if histories_params:
for histories_param in histories_params:
if histories_param not in human_inputs:
human_inputs[histories_param] = '{' + histories_param + '}'
# histories_params = JinjaPromptTemplate.from_template(template=histories).input_variables
# if histories_params:
# for histories_param in histories_params:
# if histories_param not in human_inputs:
# human_inputs[histories_param] = '{{' + histories_param + '}}'
human_message_prompt += "\n\n" + histories