feat: advanced prompt backend (#1301)

Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
Garfield Dai
2023-10-12 23:13:10 +08:00
committed by GitHub
parent 2d1cb076c6
commit 42a5b3ec17
61 changed files with 767 additions and 581 deletions

View File

@@ -1,38 +1,24 @@
import re
from langchain.schema import BaseMessage, SystemMessage, AIMessage, HumanMessage
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
from langchain.schema import BaseMessage
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompt_template import PromptTemplateParser
class PromptBuilder:
@classmethod
def parse_prompt(cls, prompt: str, inputs: dict) -> str:
prompt_template = PromptTemplateParser(prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
prompt = prompt_template.format(prompt_inputs)
return prompt
@classmethod
def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template)
prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs}
system_message = system_prompt_template.format(**prompt_inputs)
return system_message
return SystemMessage(content=cls.parse_prompt(prompt_content, inputs))
@classmethod
def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template)
prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs}
ai_message = ai_prompt_template.format(**prompt_inputs)
return ai_message
return AIMessage(content=cls.parse_prompt(prompt_content, inputs))
@classmethod
def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template)
human_message = human_prompt_template.format(**inputs)
return human_message
@classmethod
def process_template(cls, template: str):
processed_template = re.sub(r'\{{2}(.+)\}{2}', r'{\1}', template)
# processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
# processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
return processed_template
return HumanMessage(content=cls.parse_prompt(prompt_content, inputs))