feat: upgrade langchain (#430)

Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
John Wang
2023-06-25 16:49:14 +08:00
committed by GitHub
parent 1dee5de9b4
commit 3241e4015b
91 changed files with 2703 additions and 3153 deletions

View File

@@ -1,7 +1,6 @@
from typing import Union, Optional
from typing import Union, Optional, List
from langchain.callbacks import CallbackManager
from langchain.llms.fake import FakeListLLM
from langchain.callbacks.base import BaseCallbackHandler
from core.constant import llm_constant
from core.llm.error import ProviderTokenNotInitError
@@ -32,12 +31,11 @@ class LLMBuilder:
"""
@classmethod
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]:
if model_name == 'fake':
return FakeListLLM(responses=[])
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
provider = cls.get_default_provider(tenant_id)
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
mode = cls.get_mode_by_model(model_name)
if mode == 'chat':
if provider == 'openai':
@@ -52,16 +50,21 @@ class LLMBuilder:
else:
raise ValueError(f"model name {model_name} is not supported.")
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
model_kwargs = {
'top_p': kwargs.get('top_p', 1),
'frequency_penalty': kwargs.get('frequency_penalty', 0),
'presence_penalty': kwargs.get('presence_penalty', 0),
}
model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
return llm_cls(
model_name=model_name,
temperature=kwargs.get('temperature', 0),
max_tokens=kwargs.get('max_tokens', 256),
top_p=kwargs.get('top_p', 1),
frequency_penalty=kwargs.get('frequency_penalty', 0),
presence_penalty=kwargs.get('presence_penalty', 0),
callback_manager=kwargs.get('callback_manager', None),
**model_extras_kwargs,
callbacks=kwargs.get('callbacks', None),
streaming=kwargs.get('streaming', False),
# request_timeout=None
**model_credentials
@@ -69,7 +72,7 @@ class LLMBuilder:
@classmethod
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
model_name = model.get("name")
completion_params = model.get("completion_params", {})
@@ -82,7 +85,7 @@ class LLMBuilder:
frequency_penalty=completion_params.get('frequency_penalty', 0.1),
presence_penalty=completion_params.get('presence_penalty', 0.1),
streaming=streaming,
callback_manager=callback_manager
callbacks=callbacks
)
@classmethod