feat: optimize template parse (#460)
This commit is contained in:
@@ -23,7 +23,7 @@ from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
|
||||
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
|
||||
ReadOnlyConversationTokenDBStringBufferSharedMemory
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompt_template import OutLinePromptTemplate
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate
|
||||
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
|
||||
from models.model import App, AppModelConfig, Account, Conversation, Message
|
||||
|
||||
@@ -35,6 +35,8 @@ class Completion:
|
||||
"""
|
||||
errors: ProviderTokenNotInitError
|
||||
"""
|
||||
query = PromptBuilder.process_template(query)
|
||||
|
||||
memory = None
|
||||
if conversation:
|
||||
# get memory of conversation (read-only)
|
||||
@@ -141,18 +143,17 @@ class Completion:
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
|
||||
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
|
||||
# disable template string in query
|
||||
query_params = OutLinePromptTemplate.from_template(template=query).input_variables
|
||||
if query_params:
|
||||
for query_param in query_params:
|
||||
if query_param not in inputs:
|
||||
inputs[query_param] = '{' + query_param + '}'
|
||||
# query_params = JinjaPromptTemplate.from_template(template=query).input_variables
|
||||
# if query_params:
|
||||
# for query_param in query_params:
|
||||
# if query_param not in inputs:
|
||||
# inputs[query_param] = '{{' + query_param + '}}'
|
||||
|
||||
pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
|
||||
if mode == 'completion':
|
||||
prompt_template = OutLinePromptTemplate.from_template(
|
||||
prompt_template = JinjaPromptTemplate.from_template(
|
||||
template=("""Use the following CONTEXT as your learned knowledge:
|
||||
[CONTEXT]
|
||||
{context}
|
||||
{{context}}
|
||||
[END CONTEXT]
|
||||
|
||||
When answer to user:
|
||||
@@ -162,16 +163,16 @@ Avoid mentioning that you obtained the information from the context.
|
||||
And answer according to the language of the user's question.
|
||||
""" if chain_output else "")
|
||||
+ (pre_prompt + "\n" if pre_prompt else "")
|
||||
+ "{query}\n"
|
||||
+ "{{query}}\n"
|
||||
)
|
||||
|
||||
if chain_output:
|
||||
inputs['context'] = chain_output
|
||||
context_params = OutLinePromptTemplate.from_template(template=chain_output).input_variables
|
||||
if context_params:
|
||||
for context_param in context_params:
|
||||
if context_param not in inputs:
|
||||
inputs[context_param] = '{' + context_param + '}'
|
||||
# context_params = JinjaPromptTemplate.from_template(template=chain_output).input_variables
|
||||
# if context_params:
|
||||
# for context_param in context_params:
|
||||
# if context_param not in inputs:
|
||||
# inputs[context_param] = '{{' + context_param + '}}'
|
||||
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
|
||||
prompt_content = prompt_template.format(
|
||||
@@ -195,7 +196,7 @@ And answer according to the language of the user's question.
|
||||
|
||||
if pre_prompt:
|
||||
pre_prompt_inputs = {k: inputs[k] for k in
|
||||
OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
|
||||
JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
|
||||
if k in inputs}
|
||||
|
||||
if pre_prompt_inputs:
|
||||
@@ -205,7 +206,7 @@ And answer according to the language of the user's question.
|
||||
human_inputs['context'] = chain_output
|
||||
human_message_prompt += """Use the following CONTEXT as your learned knowledge.
|
||||
[CONTEXT]
|
||||
{context}
|
||||
{{context}}
|
||||
[END CONTEXT]
|
||||
|
||||
When answer to user:
|
||||
@@ -218,7 +219,7 @@ And answer according to the language of the user's question.
|
||||
if pre_prompt:
|
||||
human_message_prompt += pre_prompt
|
||||
|
||||
query_prompt = "\nHuman: {query}\nAI: "
|
||||
query_prompt = "\nHuman: {{query}}\nAI: "
|
||||
|
||||
if memory:
|
||||
# append chat histories
|
||||
@@ -234,11 +235,11 @@ And answer according to the language of the user's question.
|
||||
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
|
||||
|
||||
# disable template string in query
|
||||
histories_params = OutLinePromptTemplate.from_template(template=histories).input_variables
|
||||
if histories_params:
|
||||
for histories_param in histories_params:
|
||||
if histories_param not in human_inputs:
|
||||
human_inputs[histories_param] = '{' + histories_param + '}'
|
||||
# histories_params = JinjaPromptTemplate.from_template(template=histories).input_variables
|
||||
# if histories_params:
|
||||
# for histories_param in histories_params:
|
||||
# if histories_param not in human_inputs:
|
||||
# human_inputs[histories_param] = '{{' + histories_param + '}}'
|
||||
|
||||
human_message_prompt += "\n\n" + histories
|
||||
|
||||
|
Reference in New Issue
Block a user