feat: advanced prompt backend (#1301)
Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
@@ -3,7 +3,7 @@ import uuid
|
||||
|
||||
from core.agent.agent_executor import PlanningStrategy
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelMode
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
@@ -34,40 +34,28 @@ class AppModelConfigService:
|
||||
# max_tokens
|
||||
if 'max_tokens' not in cp:
|
||||
cp["max_tokens"] = 512
|
||||
#
|
||||
# if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
|
||||
# llm_constant.max_context_token_length[model_name]:
|
||||
# raise ValueError(
|
||||
# "max_tokens must be an integer greater than 0 "
|
||||
# "and not exceeding the maximum value of the corresponding model")
|
||||
#
|
||||
|
||||
# temperature
|
||||
if 'temperature' not in cp:
|
||||
cp["temperature"] = 1
|
||||
#
|
||||
# if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2:
|
||||
# raise ValueError("temperature must be a float between 0 and 2")
|
||||
#
|
||||
|
||||
# top_p
|
||||
if 'top_p' not in cp:
|
||||
cp["top_p"] = 1
|
||||
|
||||
# if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2:
|
||||
# raise ValueError("top_p must be a float between 0 and 2")
|
||||
#
|
||||
# presence_penalty
|
||||
if 'presence_penalty' not in cp:
|
||||
cp["presence_penalty"] = 0
|
||||
|
||||
# if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2:
|
||||
# raise ValueError("presence_penalty must be a float between -2 and 2")
|
||||
#
|
||||
# presence_penalty
|
||||
if 'frequency_penalty' not in cp:
|
||||
cp["frequency_penalty"] = 0
|
||||
|
||||
# if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2:
|
||||
# raise ValueError("frequency_penalty must be a float between -2 and 2")
|
||||
# stop
|
||||
if 'stop' not in cp:
|
||||
cp["stop"] = []
|
||||
elif not isinstance(cp["stop"], list):
|
||||
raise ValueError("stop in model.completion_params must be of list type")
|
||||
|
||||
# Filter out extra parameters
|
||||
filtered_cp = {
|
||||
@@ -75,7 +63,8 @@ class AppModelConfigService:
|
||||
"temperature": cp["temperature"],
|
||||
"top_p": cp["top_p"],
|
||||
"presence_penalty": cp["presence_penalty"],
|
||||
"frequency_penalty": cp["frequency_penalty"]
|
||||
"frequency_penalty": cp["frequency_penalty"],
|
||||
"stop": cp["stop"]
|
||||
}
|
||||
|
||||
return filtered_cp
|
||||
@@ -211,6 +200,10 @@ class AppModelConfigService:
|
||||
model_ids = [m['id'] for m in model_list]
|
||||
if config["model"]["name"] not in model_ids:
|
||||
raise ValueError("model.name must be in the specified model list")
|
||||
|
||||
# model.mode
|
||||
if 'mode' not in config['model'] or not config['model']["mode"]:
|
||||
config['model']["mode"] = ""
|
||||
|
||||
# model.completion_params
|
||||
if 'completion_params' not in config["model"]:
|
||||
@@ -339,6 +332,9 @@ class AppModelConfigService:
|
||||
# dataset_query_variable
|
||||
AppModelConfigService.is_dataset_query_variable_valid(config, mode)
|
||||
|
||||
# advanced prompt validation
|
||||
AppModelConfigService.is_advanced_prompt_valid(config, mode)
|
||||
|
||||
# Filter out extra parameters
|
||||
filtered_config = {
|
||||
"opening_statement": config["opening_statement"],
|
||||
@@ -351,12 +347,17 @@ class AppModelConfigService:
|
||||
"model": {
|
||||
"provider": config["model"]["provider"],
|
||||
"name": config["model"]["name"],
|
||||
"mode": config['model']["mode"],
|
||||
"completion_params": config["model"]["completion_params"]
|
||||
},
|
||||
"user_input_form": config["user_input_form"],
|
||||
"dataset_query_variable": config.get('dataset_query_variable'),
|
||||
"pre_prompt": config["pre_prompt"],
|
||||
"agent_mode": config["agent_mode"]
|
||||
"agent_mode": config["agent_mode"],
|
||||
"prompt_type": config["prompt_type"],
|
||||
"chat_prompt_config": config["chat_prompt_config"],
|
||||
"completion_prompt_config": config["completion_prompt_config"],
|
||||
"dataset_configs": config["dataset_configs"]
|
||||
}
|
||||
|
||||
return filtered_config
|
||||
@@ -375,4 +376,51 @@ class AppModelConfigService:
|
||||
|
||||
if dataset_exists and not dataset_query_variable:
|
||||
raise ValueError("Dataset query variable is required when dataset is exist")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def is_advanced_prompt_valid(config: dict, app_mode: str) -> None:
|
||||
# prompt_type
|
||||
if 'prompt_type' not in config or not config["prompt_type"]:
|
||||
config["prompt_type"] = "simple"
|
||||
|
||||
if config['prompt_type'] not in ['simple', 'advanced']:
|
||||
raise ValueError("prompt_type must be in ['simple', 'advanced']")
|
||||
|
||||
# chat_prompt_config
|
||||
if 'chat_prompt_config' not in config or not config["chat_prompt_config"]:
|
||||
config["chat_prompt_config"] = {}
|
||||
|
||||
if not isinstance(config["chat_prompt_config"], dict):
|
||||
raise ValueError("chat_prompt_config must be of object type")
|
||||
|
||||
# completion_prompt_config
|
||||
if 'completion_prompt_config' not in config or not config["completion_prompt_config"]:
|
||||
config["completion_prompt_config"] = {}
|
||||
|
||||
if not isinstance(config["completion_prompt_config"], dict):
|
||||
raise ValueError("completion_prompt_config must be of object type")
|
||||
|
||||
# dataset_configs
|
||||
if 'dataset_configs' not in config or not config["dataset_configs"]:
|
||||
config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
|
||||
|
||||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
|
||||
if config['prompt_type'] == 'advanced':
|
||||
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
|
||||
raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced")
|
||||
|
||||
if config['model']["mode"] not in ['chat', 'completion']:
|
||||
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
|
||||
|
||||
if app_mode == 'chat' and config['model']["mode"] == ModelMode.COMPLETION.value:
|
||||
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
|
||||
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
|
||||
|
||||
if not user_prefix:
|
||||
config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
|
||||
|
||||
if not assistant_prefix:
|
||||
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
|
||||
|
Reference in New Issue
Block a user