feat: upgrade langchain (#430)
Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
from typing import Union, Optional
|
||||
from typing import Union, Optional, List
|
||||
|
||||
from langchain.callbacks import CallbackManager
|
||||
from langchain.llms.fake import FakeListLLM
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
|
||||
from core.constant import llm_constant
|
||||
from core.llm.error import ProviderTokenNotInitError
|
||||
@@ -32,12 +31,11 @@ class LLMBuilder:
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]:
|
||||
if model_name == 'fake':
|
||||
return FakeListLLM(responses=[])
|
||||
|
||||
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
|
||||
provider = cls.get_default_provider(tenant_id)
|
||||
|
||||
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
|
||||
|
||||
mode = cls.get_mode_by_model(model_name)
|
||||
if mode == 'chat':
|
||||
if provider == 'openai':
|
||||
@@ -52,16 +50,21 @@ class LLMBuilder:
|
||||
else:
|
||||
raise ValueError(f"model name {model_name} is not supported.")
|
||||
|
||||
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
|
||||
|
||||
model_kwargs = {
|
||||
'top_p': kwargs.get('top_p', 1),
|
||||
'frequency_penalty': kwargs.get('frequency_penalty', 0),
|
||||
'presence_penalty': kwargs.get('presence_penalty', 0),
|
||||
}
|
||||
|
||||
model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
|
||||
|
||||
return llm_cls(
|
||||
model_name=model_name,
|
||||
temperature=kwargs.get('temperature', 0),
|
||||
max_tokens=kwargs.get('max_tokens', 256),
|
||||
top_p=kwargs.get('top_p', 1),
|
||||
frequency_penalty=kwargs.get('frequency_penalty', 0),
|
||||
presence_penalty=kwargs.get('presence_penalty', 0),
|
||||
callback_manager=kwargs.get('callback_manager', None),
|
||||
**model_extras_kwargs,
|
||||
callbacks=kwargs.get('callbacks', None),
|
||||
streaming=kwargs.get('streaming', False),
|
||||
# request_timeout=None
|
||||
**model_credentials
|
||||
@@ -69,7 +72,7 @@ class LLMBuilder:
|
||||
|
||||
@classmethod
|
||||
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
|
||||
callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
|
||||
callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
|
||||
model_name = model.get("name")
|
||||
completion_params = model.get("completion_params", {})
|
||||
|
||||
@@ -82,7 +85,7 @@ class LLMBuilder:
|
||||
frequency_penalty=completion_params.get('frequency_penalty', 0.1),
|
||||
presence_penalty=completion_params.get('presence_penalty', 0.1),
|
||||
streaming=streaming,
|
||||
callback_manager=callback_manager
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
Reference in New Issue
Block a user