feat: remove llm client use (#1316)

This commit is contained in:
takatost
2023-10-12 03:02:53 +08:00
committed by GitHub
parent c007dbdc13
commit cbf095465c
14 changed files with 434 additions and 353 deletions

View File

@@ -11,8 +11,8 @@ from typing import Type
import requests
from bs4 import BeautifulSoup, NavigableString, Comment, CData
from langchain.base_language import BaseLanguageModel
from langchain.chains.summarize import load_summarize_chain
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
@@ -20,8 +20,10 @@ from newspaper import Article
from pydantic import BaseModel, Field
from regex import regex
from core.chain.llm_chain import LLMChain
from core.data_loader import file_extractor
from core.data_loader.file_extractor import FileExtractor
from core.model_providers.models.llm.base import BaseLLM
FULL_TEMPLATE = """
TITLE: {title}
@@ -65,7 +67,7 @@ class WebReaderTool(BaseTool):
summary_chunk_overlap: int = 0
summary_separators: list[str] = ["\n\n", "", ".", " ", ""]
continue_reading: bool = True
llm: BaseLanguageModel = None
model_instance: BaseLLM = None
def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
try:
@@ -78,7 +80,7 @@ class WebReaderTool(BaseTool):
except Exception as e:
return f'Read this website failed, caused by: {str(e)}.'
if summary and self.llm:
if summary and self.model_instance:
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=self.summary_chunk_tokens,
chunk_overlap=self.summary_chunk_overlap,
@@ -95,10 +97,9 @@ class WebReaderTool(BaseTool):
if len(docs) > 5:
docs = docs[:5]
chain = load_summarize_chain(self.llm, chain_type="refine", callbacks=self.callbacks)
chain = self.get_summary_chain()
try:
page_contents = chain.run(docs)
# todo use cache
except Exception as e:
return f'Read this website failed, caused by: {str(e)}.'
else:
@@ -114,6 +115,23 @@ class WebReaderTool(BaseTool):
    async def _arun(self, url: str) -> str:
        # Async execution is intentionally unsupported for this tool;
        # callers must use the synchronous _run path.
        raise NotImplementedError
def get_summary_chain(self) -> RefineDocumentsChain:
initial_chain = LLMChain(
model_instance=self.model_instance,
prompt=refine_prompts.PROMPT
)
refine_chain = LLMChain(
model_instance=self.model_instance,
prompt=refine_prompts.REFINE_PROMPT
)
return RefineDocumentsChain(
initial_llm_chain=initial_chain,
refine_llm_chain=refine_chain,
document_variable_name="text",
initial_response_name="existing_answer",
callbacks=self.callbacks
)
def page_result(text: str, cursor: int, max_length: int) -> str:
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""