feat: add openllm support (#928)
This commit is contained in:
@@ -60,6 +60,9 @@ class ModelProviderFactory:
|
||||
elif provider_name == 'xinference':
|
||||
from core.model_providers.providers.xinference_provider import XinferenceProvider
|
||||
return XinferenceProvider
|
||||
elif provider_name == 'openllm':
|
||||
from core.model_providers.providers.openllm_provider import OpenLLMProvider
|
||||
return OpenLLMProvider
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
60
api/core/model_providers/models/llm/openllm_model.py
Normal file
60
api/core/model_providers/models/llm/openllm_model.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.llms import OpenLLM
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
|
||||
|
||||
class OpenLLMModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.COMPLETION
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
|
||||
client = OpenLLM(
|
||||
server_url=self.credentials.get('server_url'),
|
||||
callbacks=self.callbacks,
|
||||
**self.provider_model_kwargs
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
pass
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"OpenLLM: {str(ex)}")
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return False
|
137
api/core/model_providers/providers/openllm_provider.py
Normal file
137
api/core/model_providers/providers/openllm_provider.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import json
|
||||
from typing import Type
|
||||
|
||||
from langchain.llms import OpenLLM
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
|
||||
from core.model_providers.models.llm.openllm_model import OpenLLMModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class OpenLLMProvider(BaseModelProvider):
|
||||
@property
|
||||
def provider_name(self):
|
||||
"""
|
||||
Returns the name of a provider.
|
||||
"""
|
||||
return 'openllm'
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
return []
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
model_class = OpenLLMModel
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return model_class
|
||||
|
||||
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
|
||||
"""
|
||||
get model parameter rules.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=128),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
"""
|
||||
check model credentials valid.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
"""
|
||||
if 'server_url' not in credentials:
|
||||
raise CredentialsValidateFailedError('OpenLLM Server URL must be provided.')
|
||||
|
||||
try:
|
||||
credential_kwargs = {
|
||||
'server_url': credentials['server_url']
|
||||
}
|
||||
|
||||
llm = OpenLLM(
|
||||
max_tokens=10,
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@classmethod
|
||||
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||
credentials: dict) -> dict:
|
||||
"""
|
||||
encrypt model credentials for save.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
|
||||
return credentials
|
||||
|
||||
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
||||
"""
|
||||
get credentials for llm use.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param obfuscated:
|
||||
:return:
|
||||
"""
|
||||
if self.provider.provider_type != ProviderType.CUSTOM.value:
|
||||
raise NotImplementedError
|
||||
|
||||
provider_model = self._get_provider_model(model_name, model_type)
|
||||
|
||||
if not provider_model.encrypted_config:
|
||||
return {
|
||||
'server_url': None
|
||||
}
|
||||
|
||||
credentials = json.loads(provider_model.encrypted_config)
|
||||
if credentials['server_url']:
|
||||
credentials['server_url'] = encrypter.decrypt_token(
|
||||
self.provider.tenant_id,
|
||||
credentials['server_url']
|
||||
)
|
||||
|
||||
if obfuscated:
|
||||
credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
|
||||
|
||||
return credentials
|
||||
|
||||
@classmethod
|
||||
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
|
||||
return {}
|
||||
|
||||
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
|
||||
return {}
|
@@ -9,5 +9,6 @@
|
||||
"chatglm",
|
||||
"replicate",
|
||||
"huggingface_hub",
|
||||
"xinference"
|
||||
"xinference",
|
||||
"openllm"
|
||||
]
|
7
api/core/model_providers/rules/openllm.json
Normal file
7
api/core/model_providers/rules/openllm.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"support_provider_types": [
|
||||
"custom"
|
||||
],
|
||||
"system_config": null,
|
||||
"model_flexibility": "configurable"
|
||||
}
|
Reference in New Issue
Block a user