Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: Garfield Dai <dai.hai@foxmail.com> Co-authored-by: chenhe <guchenhe@gmail.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
@@ -2,11 +2,12 @@ import re
|
||||
import uuid
|
||||
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.moderation.factory import ModerationFactory
|
||||
from core.prompt.prompt_transform import AppMode
|
||||
from core.agent.agent_executor import PlanningStrategy
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelMode
|
||||
from core.provider_manager import ProviderManager
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
@@ -34,26 +35,6 @@ class AppModelConfigService:
|
||||
if not isinstance(cp, dict):
|
||||
raise ValueError("model.completion_params must be of object type")
|
||||
|
||||
# max_tokens
|
||||
if 'max_tokens' not in cp:
|
||||
cp["max_tokens"] = 512
|
||||
|
||||
# temperature
|
||||
if 'temperature' not in cp:
|
||||
cp["temperature"] = 1
|
||||
|
||||
# top_p
|
||||
if 'top_p' not in cp:
|
||||
cp["top_p"] = 1
|
||||
|
||||
# presence_penalty
|
||||
if 'presence_penalty' not in cp:
|
||||
cp["presence_penalty"] = 0
|
||||
|
||||
# presence_penalty
|
||||
if 'frequency_penalty' not in cp:
|
||||
cp["frequency_penalty"] = 0
|
||||
|
||||
# stop
|
||||
if 'stop' not in cp:
|
||||
cp["stop"] = []
|
||||
@@ -63,20 +44,10 @@ class AppModelConfigService:
|
||||
if len(cp["stop"]) > 4:
|
||||
raise ValueError("stop sequences must be less than 4")
|
||||
|
||||
# Filter out extra parameters
|
||||
filtered_cp = {
|
||||
"max_tokens": cp["max_tokens"],
|
||||
"temperature": cp["temperature"],
|
||||
"top_p": cp["top_p"],
|
||||
"presence_penalty": cp["presence_penalty"],
|
||||
"frequency_penalty": cp["frequency_penalty"],
|
||||
"stop": cp["stop"]
|
||||
}
|
||||
|
||||
return filtered_cp
|
||||
return cp
|
||||
|
||||
@classmethod
|
||||
def validate_configuration(cls, tenant_id: str, account: Account, config: dict, mode: str) -> dict:
|
||||
def validate_configuration(cls, tenant_id: str, account: Account, config: dict, app_mode: str) -> dict:
|
||||
# opening_statement
|
||||
if 'opening_statement' not in config or not config["opening_statement"]:
|
||||
config["opening_statement"] = ""
|
||||
@@ -140,21 +111,6 @@ class AppModelConfigService:
|
||||
if not isinstance(config["retriever_resource"]["enabled"], bool):
|
||||
raise ValueError("enabled in retriever_resource must be of boolean type")
|
||||
|
||||
# annotation reply
|
||||
if 'annotation_reply' not in config or not config["annotation_reply"]:
|
||||
config["annotation_reply"] = {
|
||||
"enabled": False
|
||||
}
|
||||
|
||||
if not isinstance(config["annotation_reply"], dict):
|
||||
raise ValueError("annotation_reply must be of dict type")
|
||||
|
||||
if "enabled" not in config["annotation_reply"] or not config["annotation_reply"]["enabled"]:
|
||||
config["annotation_reply"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["annotation_reply"]["enabled"], bool):
|
||||
raise ValueError("enabled in annotation_reply must be of boolean type")
|
||||
|
||||
# more_like_this
|
||||
if 'more_like_this' not in config or not config["more_like_this"]:
|
||||
config["more_like_this"] = {
|
||||
@@ -178,7 +134,8 @@ class AppModelConfigService:
|
||||
raise ValueError("model must be of object type")
|
||||
|
||||
# model.provider
|
||||
model_provider_names = ModelProviderFactory.get_provider_names()
|
||||
provider_entities = model_provider_factory.get_providers()
|
||||
model_provider_names = [provider.provider for provider in provider_entities]
|
||||
if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
@@ -186,18 +143,29 @@ class AppModelConfigService:
|
||||
if 'name' not in config["model"]:
|
||||
raise ValueError("model.name is required")
|
||||
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, config["model"]["provider"])
|
||||
if not model_provider:
|
||||
provider_manager = ProviderManager()
|
||||
models = provider_manager.get_configurations(tenant_id).get_models(
|
||||
provider=config["model"]["provider"],
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
if not models:
|
||||
raise ValueError("model.name must be in the specified model list")
|
||||
|
||||
model_list = model_provider.get_supported_model_list(ModelType.TEXT_GENERATION)
|
||||
model_ids = [m['id'] for m in model_list]
|
||||
model_ids = [m.model for m in models]
|
||||
if config["model"]["name"] not in model_ids:
|
||||
raise ValueError("model.name must be in the specified model list")
|
||||
|
||||
model_mode = None
|
||||
for model in models:
|
||||
if model.model == config["model"]["name"]:
|
||||
model_mode = model.model_properties.get(ModelPropertyKey.MODE)
|
||||
break
|
||||
|
||||
# model.mode
|
||||
if 'mode' not in config['model'] or not config['model']["mode"]:
|
||||
config['model']["mode"] = ""
|
||||
if model_mode:
|
||||
config['model']["mode"] = model_mode
|
||||
else:
|
||||
config['model']["mode"] = "completion"
|
||||
|
||||
# model.completion_params
|
||||
if 'completion_params' not in config["model"]:
|
||||
@@ -319,10 +287,10 @@ class AppModelConfigService:
|
||||
raise ValueError("Dataset ID does not exist, please check your permission.")
|
||||
|
||||
# dataset_query_variable
|
||||
cls.is_dataset_query_variable_valid(config, mode)
|
||||
cls.is_dataset_query_variable_valid(config, app_mode)
|
||||
|
||||
# advanced prompt validation
|
||||
cls.is_advanced_prompt_valid(config, mode)
|
||||
cls.is_advanced_prompt_valid(config, app_mode)
|
||||
|
||||
# external data tools validation
|
||||
cls.is_external_data_tools_valid(tenant_id, config)
|
||||
@@ -340,7 +308,6 @@ class AppModelConfigService:
|
||||
"suggested_questions_after_answer": config["suggested_questions_after_answer"],
|
||||
"speech_to_text": config["speech_to_text"],
|
||||
"retriever_resource": config["retriever_resource"],
|
||||
"annotation_reply": config["annotation_reply"],
|
||||
"more_like_this": config["more_like_this"],
|
||||
"sensitive_word_avoidance": config["sensitive_word_avoidance"],
|
||||
"external_data_tools": config["external_data_tools"],
|
||||
@@ -507,7 +474,7 @@ class AppModelConfigService:
|
||||
if config['model']["mode"] not in ['chat', 'completion']:
|
||||
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
|
||||
|
||||
if app_mode == AppMode.CHAT.value and config['model']["mode"] == ModelMode.COMPLETION.value:
|
||||
if app_mode == AppMode.CHAT.value and config['model']["mode"] == "completion":
|
||||
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
|
||||
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
|
||||
|
||||
@@ -517,7 +484,7 @@ class AppModelConfigService:
|
||||
if not assistant_prefix:
|
||||
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
|
||||
|
||||
if config['model']["mode"] == ModelMode.CHAT.value:
|
||||
if config['model']["mode"] == "chat":
|
||||
prompt_list = config['chat_prompt_config']['prompt']
|
||||
|
||||
if len(prompt_list) > 10:
|
||||
|
Reference in New Issue
Block a user