feat: server multi models support (#799)
@@ -2,42 +2,11 @@ import re
 import uuid
 
 from core.agent.agent_executor import PlanningStrategy
-from core.constant import llm_constant
+from core.model_providers.model_provider_factory import ModelProviderFactory
+from core.model_providers.models.entity.model_params import ModelType
 from models.account import Account
 from services.dataset_service import DatasetService
-from core.llm.llm_builder import LLMBuilder
-
-MODEL_PROVIDERS = [
-    'openai',
-    'anthropic',
-]
-
-MODELS_BY_APP_MODE = {
-    'chat': [
-        'claude-instant-1',
-        'claude-2',
-        'gpt-4',
-        'gpt-4-32k',
-        'gpt-3.5-turbo',
-        'gpt-3.5-turbo-16k',
-    ],
-    'completion': [
-        'claude-instant-1',
-        'claude-2',
-        'gpt-4',
-        'gpt-4-32k',
-        'gpt-3.5-turbo',
-        'gpt-3.5-turbo-16k',
-        'text-davinci-003',
-    ]
-}
-
-SUPPORT_AGENT_MODELS = [
-    "gpt-4",
-    "gpt-4-32k",
-    "gpt-3.5-turbo",
-    "gpt-3.5-turbo-16k",
-]
-
 SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
 
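
The hard-coded MODEL_PROVIDERS and MODELS_BY_APP_MODE tables removed above are superseded by lookups through ModelProviderFactory. A minimal sketch of how the list of text-generation models available to a tenant can be derived after this change, using only the factory calls that appear in the last hunk of this diff; the helper name, its return shape and the early returns are illustrative assumptions, not part of the commit:

    from core.model_providers.model_provider_factory import ModelProviderFactory
    from core.model_providers.models.entity.model_params import ModelType

    def available_text_generation_models(tenant_id: str, provider_name: str) -> list:
        # Unknown provider: nothing to offer (the validation code below raises instead).
        if provider_name not in ModelProviderFactory.get_provider_names():
            return []

        # Resolve the provider instance preferred by this tenant; a falsy result
        # means the tenant has no usable configuration for that provider.
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)
        if not model_provider:
            return []

        # Each entry is expected to expose at least an 'id' key, matching the
        # m['id'] access used in the validation hunk further down.
        model_list = model_provider.get_supported_model_list(ModelType.TEXT_GENERATION)
        return [m['id'] for m in model_list]
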
@@ -65,40 +34,40 @@ class AppModelConfigService:
         # max_tokens
         if 'max_tokens' not in cp:
             cp["max_tokens"] = 512
 
-        if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
-                llm_constant.max_context_token_length[model_name]:
-            raise ValueError(
-                "max_tokens must be an integer greater than 0 "
-                "and not exceeding the maximum value of the corresponding model")
-
+        #
+        # if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
+        #         llm_constant.max_context_token_length[model_name]:
+        #     raise ValueError(
+        #         "max_tokens must be an integer greater than 0 "
+        #         "and not exceeding the maximum value of the corresponding model")
+        #
         # temperature
         if 'temperature' not in cp:
             cp["temperature"] = 1
 
-        if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2:
-            raise ValueError("temperature must be a float between 0 and 2")
-
+        #
+        # if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2:
+        #     raise ValueError("temperature must be a float between 0 and 2")
+        #
         # top_p
         if 'top_p' not in cp:
             cp["top_p"] = 1
 
-        if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2:
-            raise ValueError("top_p must be a float between 0 and 2")
-
+        # if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2:
+        #     raise ValueError("top_p must be a float between 0 and 2")
+        #
         # presence_penalty
         if 'presence_penalty' not in cp:
             cp["presence_penalty"] = 0
 
-        if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2:
-            raise ValueError("presence_penalty must be a float between -2 and 2")
-
+        # if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2:
+        #     raise ValueError("presence_penalty must be a float between -2 and 2")
+        #
         # presence_penalty
         if 'frequency_penalty' not in cp:
             cp["frequency_penalty"] = 0
 
-        if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2:
-            raise ValueError("frequency_penalty must be a float between -2 and 2")
+        # if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2:
+        #     raise ValueError("frequency_penalty must be a float between -2 and 2")
 
         # Filter out extra parameters
         filtered_cp = {
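
The range checks above that depended on llm_constant.max_context_token_length are commented out rather than ported, presumably because per-model limits now live with the individual model providers instead of a global constant table. For reference, a provider-agnostic sketch of the disabled checks with the llm_constant lookup replaced by a caller-supplied cap; the function name and the default cap of 4096 are assumptions, not part of this commit:

    def apply_completion_param_defaults(cp: dict, max_tokens_cap: int = 4096) -> dict:
        # Defaults mirror the ones that stay active in the hunk above.
        cp.setdefault("max_tokens", 512)
        cp.setdefault("temperature", 1)
        cp.setdefault("top_p", 1)
        cp.setdefault("presence_penalty", 0)
        cp.setdefault("frequency_penalty", 0)

        # Equivalent of the commented-out range checks, minus the llm_constant lookup.
        if not isinstance(cp["max_tokens"], int) or not 0 < cp["max_tokens"] <= max_tokens_cap:
            raise ValueError("max_tokens must be an integer greater than 0 "
                             "and not exceeding the maximum value of the corresponding model")
        if not isinstance(cp["temperature"], (float, int)) or not 0 <= cp["temperature"] <= 2:
            raise ValueError("temperature must be a float between 0 and 2")
        if not isinstance(cp["top_p"], (float, int)) or not 0 <= cp["top_p"] <= 2:
            raise ValueError("top_p must be a float between 0 and 2")
        for key in ("presence_penalty", "frequency_penalty"):
            if not isinstance(cp[key], (float, int)) or not -2 <= cp[key] <= 2:
                raise ValueError(f"{key} must be a float between -2 and 2")
        return cp
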
@@ -112,7 +81,7 @@ class AppModelConfigService:
         return filtered_cp
 
     @staticmethod
-    def validate_configuration(account: Account, config: dict, mode: str) -> dict:
+    def validate_configuration(tenant_id: str, account: Account, config: dict) -> dict:
         # opening_statement
         if 'opening_statement' not in config or not config["opening_statement"]:
             config["opening_statement"] = ""
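
The new signature drops the app mode argument and takes the tenant id first, so model availability is resolved per tenant rather than from a fixed per-mode list. A hypothetical call site under that assumption; current_user and args are placeholders, not names taken from this diff:

    model_config_dict = AppModelConfigService.validate_configuration(
        tenant_id=current_user.current_tenant_id,  # placeholder Account-like object
        account=current_user,
        config=args['model_config']                # placeholder request payload
    )
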
@@ -211,14 +180,21 @@ class AppModelConfigService:
             raise ValueError("model must be of object type")
 
         # model.provider
-        if 'provider' not in config["model"] or config["model"]["provider"] not in MODEL_PROVIDERS:
-            raise ValueError(f"model.provider is required and must be in {str(MODEL_PROVIDERS)}")
+        model_provider_names = ModelProviderFactory.get_provider_names()
+        if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
+            raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
 
         # model.name
         if 'name' not in config["model"]:
             raise ValueError("model.name is required")
 
-        if config["model"]["name"] not in MODELS_BY_APP_MODE[mode]:
+        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, config["model"]["provider"])
+        if not model_provider:
             raise ValueError("model.name must be in the specified model list")
 
+        model_list = model_provider.get_supported_model_list(ModelType.TEXT_GENERATION)
+        model_ids = [m['id'] for m in model_list]
+        if config["model"]["name"] not in model_ids:
+            raise ValueError("model.name must be in the specified model list")
+
         # model.completion_params
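
The rewritten model.name check depends only on each entry returned by get_supported_model_list exposing an 'id' key, as the m['id'] comprehension above shows. An illustrative list shape that would satisfy it; the concrete ids are examples, not data taken from this commit:

    model_list = [
        {'id': 'gpt-3.5-turbo'},
        {'id': 'claude-2'},
    ]
    model_ids = [m['id'] for m in model_list]
    assert 'claude-2' in model_ids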