feat: universal chat in explore (#649)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Author: John Wang
Date: 2023-07-27 13:08:57 +08:00
Committed by: GitHub
Parent: 94b54b7ca9
Commit: 4fdb37771a
64 changed files with 3186 additions and 858 deletions

api/core/llm/fake.py (new file, 59 lines)

@@ -0,0 +1,59 @@
import time
from typing import List, Optional, Any, Mapping

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage, ChatResult, AIMessage, ChatGeneration, BaseLanguageModel


class FakeLLM(SimpleChatModel):
    """Fake ChatModel for testing purposes."""

    streaming: bool = False
    """Whether to stream the results or not."""
    response: str
    origin_llm: Optional[BaseLanguageModel] = None

    @property
    def _llm_type(self) -> str:
        return "fake-chat-model"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return the configured response, regardless of the input messages."""
        return self.response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"response": self.response}

    def get_num_tokens(self, text: str) -> int:
        return self.origin_llm.get_num_tokens(text) if self.origin_llm else 0

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
        if self.streaming:
            for token in output_str:
                if run_manager:
                    run_manager.on_llm_new_token(token)
                time.sleep(0.01)
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        llm_output = {"token_usage": {
            'prompt_tokens': 0,
            'completion_tokens': 0,
            'total_tokens': 0,
        }}
        return ChatResult(generations=[generation], llm_output=llm_output)
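
A minimal usage sketch for tests (assuming api/ is the import root, so the module path is core.llm.fake; the names below are illustrative):

from langchain.schema import HumanMessage

from core.llm.fake import FakeLLM

# FakeLLM always returns the canned response and reports zero token usage,
# so unit tests stay deterministic and make no network calls.
llm = FakeLLM(response="pong")
result = llm.generate([[HumanMessage(content="ping")]])
assert result.generations[0][0].text == "pong"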


@@ -10,6 +10,9 @@ from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName

+AZURE_OPENAI_API_VERSION = '2023-07-01-preview'

class AzureProvider(BaseProvider):

    def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
        return []
@@ -50,9 +53,10 @@ class AzureProvider(BaseProvider):
        """
        config = self.get_provider_api_key(model_id=model_id)
        config['openai_api_type'] = 'azure'
+        config['openai_api_version'] = AZURE_OPENAI_API_VERSION
        if model_id == 'text-embedding-ada-002':
            config['deployment'] = model_id.replace('.', '') if model_id else None
-            config['chunk_size'] = 1
+            config['chunk_size'] = 16
        else:
            config['deployment_name'] = model_id.replace('.', '') if model_id else None
        return config
@@ -69,7 +73,7 @@ class AzureProvider(BaseProvider):
        except:
            config = {
                'openai_api_type': 'azure',
-                'openai_api_version': '2023-03-15-preview',
+                'openai_api_version': AZURE_OPENAI_API_VERSION,
                'openai_api_base': '',
                'openai_api_key': ''
            }
@@ -78,7 +82,7 @@ class AzureProvider(BaseProvider):
        if not config.get('openai_api_key'):
            config = {
                'openai_api_type': 'azure',
-                'openai_api_version': '2023-03-15-preview',
+                'openai_api_version': AZURE_OPENAI_API_VERSION,
                'openai_api_base': '',
                'openai_api_key': ''
            }
@@ -100,7 +104,7 @@ class AzureProvider(BaseProvider):
                raise ValueError('Config must be a object.')

            if 'openai_api_version' not in config:
-                config['openai_api_version'] = '2023-03-15-preview'
+                config['openai_api_version'] = AZURE_OPENAI_API_VERSION

            self.check_embedding_model(credentials=config)
        except ValidateFailedError as e:
@@ -119,7 +123,7 @@ class AzureProvider(BaseProvider):
        """
        return json.dumps({
            'openai_api_type': 'azure',
-            'openai_api_version': '2023-03-15-preview',
+            'openai_api_version': AZURE_OPENAI_API_VERSION,
            'openai_api_base': config['openai_api_base'],
            'openai_api_key': self.encrypt_token(config['openai_api_key'])
        })
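
For context, chunk_size on langchain's OpenAIEmbeddings caps how many texts are batched into a single embeddings request; Azure's embeddings endpoint accepted at most 16 inputs per request at the time, hence the bump from 1. A hedged sketch of the setting in isolation (endpoint, key, and resource name are placeholders):

from langchain.embeddings import OpenAIEmbeddings

# With chunk_size=16, embedding 100 texts issues ceil(100 / 16) = 7 API calls
# instead of 100 single-item calls.
embeddings = OpenAIEmbeddings(
    openai_api_type='azure',
    openai_api_version='2023-07-01-preview',
    openai_api_base='https://my-resource.openai.azure.com',  # placeholder endpoint
    openai_api_key='...',  # placeholder key
    deployment='text-embedding-ada-002',
    chunk_size=16,
)
vectors = embeddings.embed_documents(['first text', 'second text'])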


@@ -1,7 +1,8 @@
-from langchain.callbacks.manager import Callbacks
-from langchain.schema import BaseMessage, LLMResult
+from langchain.callbacks.manager import Callbacks, CallbackManagerForLLMRun
+from langchain.chat_models.openai import _convert_dict_to_message
+from langchain.schema import BaseMessage, LLMResult, ChatResult, ChatGeneration
from langchain.chat_models import AzureChatOpenAI
-from typing import Optional, List, Dict, Any
+from typing import Optional, List, Dict, Any, Tuple, Union
from pydantic import root_validator
@@ -9,6 +10,11 @@ from core.llm.wrappers.openai_wrapper import handle_openai_exceptions

class StreamableAzureChatOpenAI(AzureChatOpenAI):
+    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
+    """Timeout for requests to OpenAI completion API: 5s to connect, 300s to read."""
+    max_retries: int = 1
+    """Maximum number of retries to make when generating."""
+
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and python package exist in the environment."""
@@ -71,3 +77,43 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
        params['model_kwargs'] = model_kwargs

        return params
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs}
+        if self.streaming:
+            inner_completion = ""
+            role = "assistant"
+            params["stream"] = True
+            function_call: Optional[dict] = None
+            for stream_resp in self.completion_with_retry(
+                messages=message_dicts, **params
+            ):
+                if len(stream_resp["choices"]) > 0:
+                    role = stream_resp["choices"][0]["delta"].get("role", role)
+                    token = stream_resp["choices"][0]["delta"].get("content") or ""
+                    inner_completion += token
+                    _function_call = stream_resp["choices"][0]["delta"].get("function_call")
+                    if _function_call:
+                        if function_call is None:
+                            function_call = _function_call
+                        else:
+                            function_call["arguments"] += _function_call["arguments"]
+                    if run_manager:
+                        run_manager.on_llm_new_token(token)
+            message = _convert_dict_to_message(
+                {
+                    "content": inner_completion,
+                    "role": role,
+                    "function_call": function_call,
+                }
+            )
+            return ChatResult(generations=[ChatGeneration(message=message)])
+
+        response = self.completion_with_retry(messages=message_dicts, **params)
+        return self._create_chat_result(response)
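
This override accumulates the streamed delta chunks, including incremental function_call arguments, into one final message while forwarding each token to the run manager's callbacks. A hedged sketch of driving it with a stdout streaming callback (deployment name, endpoint, and key are placeholders):

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import HumanMessage

llm = StreamableAzureChatOpenAI(
    deployment_name='gpt-35-turbo',  # placeholder deployment
    openai_api_base='https://my-resource.openai.azure.com',  # placeholder endpoint
    openai_api_key='...',  # placeholder key
    openai_api_version='2023-07-01-preview',
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],  # prints each token as it arrives
)
llm.predict_messages([HumanMessage(content='Say hello')])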


@@ -1,7 +1,7 @@
from langchain.callbacks.manager import Callbacks
from langchain.llms import AzureOpenAI
from langchain.schema import LLMResult
-from typing import Optional, List, Dict, Mapping, Any
+from typing import Optional, List, Dict, Mapping, Any, Union, Tuple
from pydantic import root_validator
@@ -11,6 +11,10 @@ from core.llm.wrappers.openai_wrapper import handle_openai_exceptions

class StreamableAzureOpenAI(AzureOpenAI):
    openai_api_type: str = "azure"
    openai_api_version: str = ""
+    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
+    """Timeout for requests to OpenAI completion API: 5s to connect, 300s to read."""
+    max_retries: int = 1
+    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:


@@ -1,8 +1,10 @@
from typing import List, Optional, Any, Dict

+from httpx import Timeout
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
-from langchain.schema import BaseMessage, LLMResult
+from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage
from pydantic import root_validator

from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
@@ -12,6 +14,14 @@ class StreamableChatAnthropic(ChatAnthropic):
    Wrapper around Anthropic's large language model.
    """

+    default_request_timeout: Optional[float] = Timeout(timeout=300.0, connect=5.0)
+
+    @root_validator()
+    def prepare_params(cls, values: Dict) -> Dict:
+        values['model_name'] = values.get('model')
+        values['max_tokens'] = values.get('max_tokens_to_sample')
+        return values
+
    @handle_anthropic_exceptions
    def generate(
        self,
@@ -37,3 +47,16 @@ class StreamableChatAnthropic(ChatAnthropic):
        del params['presence_penalty']

        return params
+
+    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
+        if isinstance(message, ChatMessage):
+            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+        elif isinstance(message, HumanMessage):
+            message_text = f"{self.HUMAN_PROMPT} {message.content}"
+        elif isinstance(message, AIMessage):
+            message_text = f"{self.AI_PROMPT} {message.content}"
+        elif isinstance(message, SystemMessage):
+            message_text = f"<admin>{message.content}</admin>"
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        return message_text
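
The <admin> wrapping replaces langchain's default rendering of SystemMessage in the flattened Anthropic prompt. A hedged sketch of the per-message formatting this override produces, assuming the anthropic package is installed so HUMAN_PROMPT ("\n\nHuman:") and AI_PROMPT ("\n\nAssistant:") resolve (the key is a dummy placeholder):

from langchain.schema import AIMessage, HumanMessage, SystemMessage

llm = StreamableChatAnthropic(anthropic_api_key='dummy')  # placeholder key
assert llm._convert_one_message_to_text(SystemMessage(content='Be terse')) == '<admin>Be terse</admin>'
assert llm._convert_one_message_to_text(HumanMessage(content='Hi')) == '\n\nHuman: Hi'
assert llm._convert_one_message_to_text(AIMessage(content='Hello')) == '\n\nAssistant: Hello'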


@@ -3,7 +3,7 @@ import os

from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, LLMResult
from langchain.chat_models import ChatOpenAI
-from typing import Optional, List, Dict, Any
+from typing import Optional, List, Dict, Any, Union, Tuple
from pydantic import root_validator
@@ -11,6 +11,10 @@ from core.llm.wrappers.openai_wrapper import handle_openai_exceptions

class StreamableChatOpenAI(ChatOpenAI):
+    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
+    """Timeout for requests to OpenAI completion API: 5s to connect, 300s to read."""
+    max_retries: int = 1
+    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:


@@ -2,7 +2,7 @@ import os

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
-from typing import Optional, List, Dict, Any, Mapping
+from typing import Optional, List, Dict, Any, Mapping, Union, Tuple
from langchain import OpenAI
from pydantic import root_validator
@@ -10,6 +10,10 @@ from core.llm.wrappers.openai_wrapper import handle_openai_exceptions

class StreamableOpenAI(OpenAI):
+    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
+    """Timeout for requests to OpenAI completion API: 5s to connect, 300s to read."""
+    max_retries: int = 1
+    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
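
Across these wrappers, the (5.0, 300.0) tuple follows the requests convention of (connect timeout, read timeout), and max_retries=1 keeps a failed call from stacking long retry loops on top of the 300-second read window. A hedged sketch of overriding both per instance (the key is a placeholder):

llm = StreamableOpenAI(
    openai_api_key='...',           # placeholder key
    request_timeout=(5.0, 60.0),    # fail faster on slow reads
    max_retries=2,                  # one extra attempt over the new default
)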