feat: server multi models support (#799)

This commit is contained in:
takatost
2023-08-12 00:57:00 +08:00
committed by GitHub
parent d8b712b325
commit 5fa2161b05
213 changed files with 10556 additions and 2579 deletions

View File

@@ -2,42 +2,11 @@ import re
import uuid
from core.agent.agent_executor import PlanningStrategy
from core.constant import llm_constant
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType
from models.account import Account
from services.dataset_service import DatasetService
from core.llm.llm_builder import LLMBuilder
MODEL_PROVIDERS = [
'openai',
'anthropic',
]
MODELS_BY_APP_MODE = {
'chat': [
'claude-instant-1',
'claude-2',
'gpt-4',
'gpt-4-32k',
'gpt-3.5-turbo',
'gpt-3.5-turbo-16k',
],
'completion': [
'claude-instant-1',
'claude-2',
'gpt-4',
'gpt-4-32k',
'gpt-3.5-turbo',
'gpt-3.5-turbo-16k',
'text-davinci-003',
]
}
SUPPORT_AGENT_MODELS = [
"gpt-4",
"gpt-4-32k",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
]
SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
@@ -65,40 +34,40 @@ class AppModelConfigService:
# max_tokens
if 'max_tokens' not in cp:
cp["max_tokens"] = 512
if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
llm_constant.max_context_token_length[model_name]:
raise ValueError(
"max_tokens must be an integer greater than 0 "
"and not exceeding the maximum value of the corresponding model")
#
# if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
# llm_constant.max_context_token_length[model_name]:
# raise ValueError(
# "max_tokens must be an integer greater than 0 "
# "and not exceeding the maximum value of the corresponding model")
#
# temperature
if 'temperature' not in cp:
cp["temperature"] = 1
if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2:
raise ValueError("temperature must be a float between 0 and 2")
#
# if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2:
# raise ValueError("temperature must be a float between 0 and 2")
#
# top_p
if 'top_p' not in cp:
cp["top_p"] = 1
if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2:
raise ValueError("top_p must be a float between 0 and 2")
# if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2:
# raise ValueError("top_p must be a float between 0 and 2")
#
# presence_penalty
if 'presence_penalty' not in cp:
cp["presence_penalty"] = 0
if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2:
raise ValueError("presence_penalty must be a float between -2 and 2")
# if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2:
# raise ValueError("presence_penalty must be a float between -2 and 2")
#
# presence_penalty
if 'frequency_penalty' not in cp:
cp["frequency_penalty"] = 0
if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2:
raise ValueError("frequency_penalty must be a float between -2 and 2")
# if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2:
# raise ValueError("frequency_penalty must be a float between -2 and 2")
# Filter out extra parameters
filtered_cp = {
@@ -112,7 +81,7 @@ class AppModelConfigService:
return filtered_cp
@staticmethod
def validate_configuration(account: Account, config: dict, mode: str) -> dict:
def validate_configuration(tenant_id: str, account: Account, config: dict) -> dict:
# opening_statement
if 'opening_statement' not in config or not config["opening_statement"]:
config["opening_statement"] = ""
@@ -211,14 +180,21 @@ class AppModelConfigService:
raise ValueError("model must be of object type")
# model.provider
if 'provider' not in config["model"] or config["model"]["provider"] not in MODEL_PROVIDERS:
raise ValueError(f"model.provider is required and must be in {str(MODEL_PROVIDERS)}")
model_provider_names = ModelProviderFactory.get_provider_names()
if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
# model.name
if 'name' not in config["model"]:
raise ValueError("model.name is required")
if config["model"]["name"] not in MODELS_BY_APP_MODE[mode]:
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, config["model"]["provider"])
if not model_provider:
raise ValueError("model.name must be in the specified model list")
model_list = model_provider.get_supported_model_list(ModelType.TEXT_GENERATION)
model_ids = [m['id'] for m in model_list]
if config["model"]["name"] not in model_ids:
raise ValueError("model.name must be in the specified model list")
# model.completion_params