feat: add baichuan prompt (#985)

Author: takatost
Date: 2023-08-24 10:22:36 +08:00
Committed by: GitHub
Parent: 9b247fccd4
Commit: 2c30d19cbe

9 changed files with 213 additions and 130 deletions
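
This change routes prompt construction through per-model rule files: BaseLLM gains a get_prompt() entry point that loads a JSON rule file named by prompt_file_name(), and the HuggingFace Hub, OpenLLM, and Xinference models override prompt_file_name() to pick Baichuan-specific rules whenever the deployed model name contains "baichuan". A minimal sketch of the resulting call flow (method names are from the diff below; the model variable and its inputs are hypothetical):

# model is any loaded BaseLLM subclass instance; mode is 'chat' or 'completion'
rule_file = model.prompt_file_name('chat')             # 'common_chat', or 'baichuan_chat' for baichuan models
rules = model._read_prompt_rules_from_file(rule_file)  # loads <rule_file>.json from prompt/generate_prompts
messages, stops = model.get_prompt('chat', pre_prompt, inputs, query, context, memory)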


@@ -1,17 +1,24 @@
import json
import os
import re
from abc import abstractmethod
from typing import List, Optional, Any, Union, Tuple
import decimal

from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration

from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.third_party.langchain.llms.fake import FakeLLM

import logging

logger = logging.getLogger(__name__)
@@ -76,13 +83,14 @@ class BaseLLM(BaseProviderModel):
    def price_config(self) -> dict:
        def get_or_default():
            default_price_config = {
                'prompt': decimal.Decimal('0'),
                'completion': decimal.Decimal('0'),
                'unit': decimal.Decimal('0'),
                'currency': 'USD'
            }
            rules = self.model_provider.get_rules()
            price_config = rules['price_config'][
                self.base_model_name] if 'price_config' in rules else default_price_config
            price_config = {
                'prompt': decimal.Decimal(price_config['prompt']),
                'completion': decimal.Decimal(price_config['completion']),

@@ -90,7 +98,7 @@ class BaseLLM(BaseProviderModel):
                'currency': price_config['currency']
            }
            return price_config

        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()

        logger.debug(f"model: {self.name} price_config: {self._price_config}")
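
For reference, a sketch of the provider rules entry that get_or_default() above consumes; the key names come from the code, while the model name and the numbers are purely illustrative:

rules = {
    'price_config': {
        'baichuan-13b': {         # hypothetical base_model_name
            'prompt': '0.0',      # price per `unit` prompt tokens
            'completion': '0.0',  # price per `unit` completion tokens
            'unit': '0.001',      # e.g. prices quoted per 1,000 tokens
            'currency': 'USD'
        }
    }
}
# Each numeric field is passed through decimal.Decimal(), so string values are safe.
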
@@ -158,7 +166,8 @@ class BaseLLM(BaseProviderModel):
            total_tokens = result.llm_output['token_usage']['total_tokens']
        else:
            prompt_tokens = self.get_num_tokens(messages)
            completion_tokens = self.get_num_tokens(
                [PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
            total_tokens = prompt_tokens + completion_tokens

        self.model_provider.update_last_used()
@@ -293,6 +302,119 @@ class BaseLLM(BaseProviderModel):
    def support_streaming(cls):
        return False

    def get_prompt(self, mode: str,
                   pre_prompt: str, inputs: dict,
                   query: str,
                   context: Optional[str],
                   memory: Optional[BaseChatMemory]) -> \
            Tuple[List[PromptMessage], Optional[List[str]]]:
        prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
        return [PromptMessage(content=prompt)], stops

    def prompt_file_name(self, mode: str) -> str:
        if mode == 'completion':
            return 'common_completion'
        else:
            return 'common_chat'

    def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
                             query: str,
                             context: Optional[str],
                             memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
        context_prompt_content = ''
        if context and 'context_prompt' in prompt_rules:
            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
            context_prompt_content = prompt_template.format(
                context=context
            )

        pre_prompt_content = ''
        if pre_prompt:
            prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
            pre_prompt_content = prompt_template.format(
                **prompt_inputs
            )

        prompt = ''
        for order in prompt_rules['system_prompt_orders']:
            if order == 'context_prompt':
                prompt += context_prompt_content
            elif order == 'pre_prompt':
                prompt += (pre_prompt_content + '\n\n') if pre_prompt_content else ''

        query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'

        if memory and 'histories_prompt' in prompt_rules:
            # append chat histories
            tmp_human_message = PromptBuilder.to_human_message(
                prompt_content=prompt + query_prompt,
                inputs={
                    'query': query
                }
            )

            if self.model_rules.max_tokens.max:
                curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
                max_tokens = self.model_kwargs.max_tokens
                rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
                rest_tokens = max(rest_tokens, 0)
            else:
                rest_tokens = 2000

            memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
            memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'

            histories = self._get_history_messages_from_memory(memory, rest_tokens)
            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
            histories_prompt_content = prompt_template.format(
                histories=histories
            )

            prompt = ''
            for order in prompt_rules['system_prompt_orders']:
                if order == 'context_prompt':
                    prompt += context_prompt_content
                elif order == 'pre_prompt':
                    prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
                elif order == 'histories_prompt':
                    prompt += histories_prompt_content

        prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
        query_prompt_content = prompt_template.format(
            query=query
        )

        prompt += query_prompt_content

        prompt = re.sub(r'<\|.*?\|>', '', prompt)

        stops = prompt_rules.get('stops')
        if stops is not None and len(stops) == 0:
            stops = None

        return prompt, stops

    def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
        # Get the absolute path of the subdirectory
        prompt_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
            'prompt/generate_prompts')

        json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
        # Open the JSON file and read its content
        with open(json_file_path, 'r') as json_file:
            return json.load(json_file)

    def _get_history_messages_from_memory(self, memory: BaseChatMemory,
                                          max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                  model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
        if not model_mode:
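
The new rule files under prompt/generate_prompts are not included in this diff. Judging from the keys _get_prompt_and_stop() reads, a chat rule file plausibly deserializes to something like the following; the template strings are illustrative, not the shipped common_chat/baichuan_chat content:

prompt_rules = {
    'context_prompt': 'Use the following context as your learned knowledge:\n{{context}}\n\n',  # Jinja2 template
    'histories_prompt': '{{histories}}\n',
    'system_prompt_orders': ['context_prompt', 'pre_prompt', 'histories_prompt'],
    'query_prompt': 'Human: {{query}}\nAssistant: ',
    'human_prefix': 'Human',           # prefix used when rendering memory
    'assistant_prefix': 'Assistant',
    'stops': ['\nHuman:']              # an empty list is normalized to None
}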


@@ -60,6 +60,15 @@ class HuggingfaceHubModel(BaseLLM):
        prompts = self._get_prompt_from_messages(messages)
        return self._client.get_num_tokens(prompts)

    def prompt_file_name(self, mode: str) -> str:
        if 'baichuan' in self.name.lower():
            if mode == 'completion':
                return 'baichuan_completion'
            else:
                return 'baichuan_chat'
        else:
            return super().prompt_file_name(mode)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        self.client.model_kwargs = provider_model_kwargs
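
The override dispatches on the deployed model's name; OpenLLMModel and XinferenceModel below repeat it verbatim. For illustration (the model name is hypothetical):

# e.g. for a model deployed as 'Baichuan-13B-Chat':
model.prompt_file_name('completion')  # -> 'baichuan_completion'
model.prompt_file_name('chat')        # -> 'baichuan_chat'
# any other name falls back to BaseLLM's 'common_completion' / 'common_chat'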


@@ -49,6 +49,15 @@ class OpenLLMModel(BaseLLM):
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens(prompts), 0)

    def prompt_file_name(self, mode: str) -> str:
        if 'baichuan' in self.name.lower():
            if mode == 'completion':
                return 'baichuan_completion'
            else:
                return 'baichuan_chat'
        else:
            return super().prompt_file_name(mode)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        pass


@@ -59,6 +59,15 @@ class XinferenceModel(BaseLLM):
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens(prompts), 0)

    def prompt_file_name(self, mode: str) -> str:
        if 'baichuan' in self.name.lower():
            if mode == 'completion':
                return 'baichuan_completion'
            else:
                return 'baichuan_chat'
        else:
            return super().prompt_file_name(mode)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        pass
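
Taken together, a caller would exercise the new path roughly like this; a sketch only, with a hypothetical model instance and inputs:

prompt_messages, stops = model.get_prompt(
    mode='chat',
    pre_prompt='You are a helpful assistant.',  # app-configured prefix prompt (Jinja2)
    inputs={},                                  # variables for the pre_prompt template
    query='Hello!',
    context=None,                               # retrieved context passage, if any
    memory=None,                                # BaseChatMemory with chat history, if any
)
# prompt_messages is a single-element [PromptMessage(content=<assembled prompt>)];
# stops comes from the rule file's 'stops' list (None when absent or empty).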