Model Runtime (#1858)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
takatost
2024-01-02 23:42:00 +08:00
committed by GitHub
parent e91dd28a76
commit d069c668f8
807 changed files with 171310 additions and 23806 deletions

View File

@@ -2,11 +2,12 @@ import re
import uuid
from core.external_data_tool.factory import ExternalDataToolFactory
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.moderation.factory import ModerationFactory
from core.prompt.prompt_transform import AppMode
from core.agent.agent_executor import PlanningStrategy
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType, ModelMode
from core.provider_manager import ProviderManager
from models.account import Account
from services.dataset_service import DatasetService
@@ -34,26 +35,6 @@ class AppModelConfigService:
if not isinstance(cp, dict):
raise ValueError("model.completion_params must be of object type")
# max_tokens
if 'max_tokens' not in cp:
cp["max_tokens"] = 512
# temperature
if 'temperature' not in cp:
cp["temperature"] = 1
# top_p
if 'top_p' not in cp:
cp["top_p"] = 1
# presence_penalty
if 'presence_penalty' not in cp:
cp["presence_penalty"] = 0
# presence_penalty
if 'frequency_penalty' not in cp:
cp["frequency_penalty"] = 0
# stop
if 'stop' not in cp:
cp["stop"] = []
@@ -63,20 +44,10 @@ class AppModelConfigService:
if len(cp["stop"]) > 4:
raise ValueError("stop sequences must be less than 4")
# Filter out extra parameters
filtered_cp = {
"max_tokens": cp["max_tokens"],
"temperature": cp["temperature"],
"top_p": cp["top_p"],
"presence_penalty": cp["presence_penalty"],
"frequency_penalty": cp["frequency_penalty"],
"stop": cp["stop"]
}
return filtered_cp
return cp
@classmethod
def validate_configuration(cls, tenant_id: str, account: Account, config: dict, mode: str) -> dict:
def validate_configuration(cls, tenant_id: str, account: Account, config: dict, app_mode: str) -> dict:
# opening_statement
if 'opening_statement' not in config or not config["opening_statement"]:
config["opening_statement"] = ""
@@ -140,21 +111,6 @@ class AppModelConfigService:
if not isinstance(config["retriever_resource"]["enabled"], bool):
raise ValueError("enabled in retriever_resource must be of boolean type")
# annotation reply
if 'annotation_reply' not in config or not config["annotation_reply"]:
config["annotation_reply"] = {
"enabled": False
}
if not isinstance(config["annotation_reply"], dict):
raise ValueError("annotation_reply must be of dict type")
if "enabled" not in config["annotation_reply"] or not config["annotation_reply"]["enabled"]:
config["annotation_reply"]["enabled"] = False
if not isinstance(config["annotation_reply"]["enabled"], bool):
raise ValueError("enabled in annotation_reply must be of boolean type")
# more_like_this
if 'more_like_this' not in config or not config["more_like_this"]:
config["more_like_this"] = {
@@ -178,7 +134,8 @@ class AppModelConfigService:
raise ValueError("model must be of object type")
# model.provider
model_provider_names = ModelProviderFactory.get_provider_names()
provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
@@ -186,18 +143,29 @@ class AppModelConfigService:
if 'name' not in config["model"]:
raise ValueError("model.name is required")
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, config["model"]["provider"])
if not model_provider:
provider_manager = ProviderManager()
models = provider_manager.get_configurations(tenant_id).get_models(
provider=config["model"]["provider"],
model_type=ModelType.LLM
)
if not models:
raise ValueError("model.name must be in the specified model list")
model_list = model_provider.get_supported_model_list(ModelType.TEXT_GENERATION)
model_ids = [m['id'] for m in model_list]
model_ids = [m.model for m in models]
if config["model"]["name"] not in model_ids:
raise ValueError("model.name must be in the specified model list")
model_mode = None
for model in models:
if model.model == config["model"]["name"]:
model_mode = model.model_properties.get(ModelPropertyKey.MODE)
break
# model.mode
if 'mode' not in config['model'] or not config['model']["mode"]:
config['model']["mode"] = ""
if model_mode:
config['model']["mode"] = model_mode
else:
config['model']["mode"] = "completion"
# model.completion_params
if 'completion_params' not in config["model"]:
@@ -319,10 +287,10 @@ class AppModelConfigService:
raise ValueError("Dataset ID does not exist, please check your permission.")
# dataset_query_variable
cls.is_dataset_query_variable_valid(config, mode)
cls.is_dataset_query_variable_valid(config, app_mode)
# advanced prompt validation
cls.is_advanced_prompt_valid(config, mode)
cls.is_advanced_prompt_valid(config, app_mode)
# external data tools validation
cls.is_external_data_tools_valid(tenant_id, config)
@@ -340,7 +308,6 @@ class AppModelConfigService:
"suggested_questions_after_answer": config["suggested_questions_after_answer"],
"speech_to_text": config["speech_to_text"],
"retriever_resource": config["retriever_resource"],
"annotation_reply": config["annotation_reply"],
"more_like_this": config["more_like_this"],
"sensitive_word_avoidance": config["sensitive_word_avoidance"],
"external_data_tools": config["external_data_tools"],
@@ -507,7 +474,7 @@ class AppModelConfigService:
if config['model']["mode"] not in ['chat', 'completion']:
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
if app_mode == AppMode.CHAT.value and config['model']["mode"] == ModelMode.COMPLETION.value:
if app_mode == AppMode.CHAT.value and config['model']["mode"] == "completion":
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
@@ -517,7 +484,7 @@ class AppModelConfigService:
if not assistant_prefix:
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
if config['model']["mode"] == ModelMode.CHAT.value:
if config['model']["mode"] == "chat":
prompt_list = config['chat_prompt_config']['prompt']
if len(prompt_list) > 10: