feat: advanced prompt backend (#1301)

Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
Garfield Dai
2023-10-12 23:13:10 +08:00
committed by GitHub
parent 2d1cb076c6
commit 42a5b3ec17
61 changed files with 767 additions and 581 deletions

View File

@@ -13,13 +13,13 @@ class LLMRunResult(BaseModel):
class MessageType(enum.Enum):
HUMAN = 'human'
USER = 'user'
ASSISTANT = 'assistant'
SYSTEM = 'system'
class PromptMessage(BaseModel):
type: MessageType = MessageType.HUMAN
type: MessageType = MessageType.USER
content: str = ''
function_call: dict = None
@@ -27,7 +27,7 @@ class PromptMessage(BaseModel):
def to_lc_messages(messages: list[PromptMessage]):
lc_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
if message.type == MessageType.USER:
lc_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
additional_kwargs = {}
@@ -44,7 +44,7 @@ def to_prompt_messages(messages: list[BaseMessage]):
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
elif isinstance(message, AIMessage):
message_kwargs = {
'content': message.content,
@@ -58,7 +58,7 @@ def to_prompt_messages(messages: list[BaseMessage]):
elif isinstance(message, SystemMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
elif isinstance(message, FunctionMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
return prompt_messages

View File

@@ -18,7 +18,7 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompt_template import PromptTemplateParser
from core.third_party.langchain.llms.fake import FakeLLM
import logging
@@ -232,7 +232,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return:
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
@@ -250,7 +250,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return: decimal.Decimal('0.0001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
@@ -265,7 +265,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return: decimal.Decimal('0.000001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
price_unit = self.price_config['unit']
else:
price_unit = self.price_config['unit']
@@ -330,6 +330,85 @@ class BaseLLM(BaseProviderModel):
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
return [PromptMessage(content=prompt)], stops
def get_advanced_prompt(self, app_mode: str,
app_model_config: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
model_mode = app_model_config.model_dict['mode']
conversation_histories_role = {}
raw_prompt_list = []
prompt_messages = []
if app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
raw_prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
elif app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.CHAT.value:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
raw_prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
else:
raise Exception("app_mode or model_mode not support")
for prompt_item in raw_prompt_list:
prompt = prompt_item['text']
# set prompt template variables
prompt_template = PromptTemplateParser(template=prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
if '#context#' in prompt:
if context:
prompt_inputs['#context#'] = context
else:
prompt_inputs['#context#'] = ''
if '#query#' in prompt:
if query:
prompt_inputs['#query#'] = query
else:
prompt_inputs['#query#'] = ''
if '#histories#' in prompt:
if memory and app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
memory.human_prefix = conversation_histories_role['user_prefix']
memory.ai_prefix = conversation_histories_role['assistant_prefix']
histories = self._get_history_messages_from_memory(memory, 2000)
prompt_inputs['#histories#'] = histories
else:
prompt_inputs['#histories#'] = ''
prompt = prompt_template.format(
prompt_inputs
)
prompt = re.sub(r'<\|.*?\|>', '', prompt)
prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
if memory and app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
memory.human_prefix = MessageType.USER.value
memory.ai_prefix = MessageType.ASSISTANT.value
histories = self._get_history_messages_list_from_memory(memory, 2000)
prompt_messages.extend(histories)
if app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
return prompt_messages
def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
return 'common_completion'
@@ -342,17 +421,17 @@ class BaseLLM(BaseProviderModel):
memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
context_prompt_content = ''
if context and 'context_prompt' in prompt_rules:
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
context_prompt_content = prompt_template.format(
context=context
{'context': context}
)
pre_prompt_content = ''
if pre_prompt:
prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
prompt_template = PromptTemplateParser(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
pre_prompt_content = prompt_template.format(
**prompt_inputs
prompt_inputs
)
prompt = ''
@@ -385,10 +464,8 @@ class BaseLLM(BaseProviderModel):
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
histories = self._get_history_messages_from_memory(memory, rest_tokens)
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
histories_prompt_content = prompt_template.format(
histories=histories
)
prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
histories_prompt_content = prompt_template.format({'histories': histories})
prompt = ''
for order in prompt_rules['system_prompt_orders']:
@@ -399,10 +476,8 @@ class BaseLLM(BaseProviderModel):
elif order == 'histories_prompt':
prompt += histories_prompt_content
prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
query_prompt_content = prompt_template.format(
query=query
)
prompt_template = PromptTemplateParser(template=query_prompt)
query_prompt_content = prompt_template.format({'query': query})
prompt += query_prompt_content
@@ -433,6 +508,16 @@ class BaseLLM(BaseProviderModel):
external_context = memory.load_memory_variables({})
return external_context[memory_key]
def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> List[PromptMessage]:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory.return_messages = True
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
memory.return_messages = False
return to_prompt_messages(external_context[memory_key])
def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
if not model_mode: