feat: remove llm client use (#1316)
This commit is contained in:
@@ -11,8 +11,8 @@ from typing import Type
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup, NavigableString, Comment, CData
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.summarize import load_summarize_chain
|
||||
from langchain.chains import RefineDocumentsChain
|
||||
from langchain.chains.summarize import refine_prompts
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain.tools.base import BaseTool
|
||||
@@ -20,8 +20,10 @@ from newspaper import Article
|
||||
from pydantic import BaseModel, Field
|
||||
from regex import regex
|
||||
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.data_loader import file_extractor
|
||||
from core.data_loader.file_extractor import FileExtractor
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
|
||||
FULL_TEMPLATE = """
|
||||
TITLE: {title}
|
||||
@@ -65,7 +67,7 @@ class WebReaderTool(BaseTool):
|
||||
summary_chunk_overlap: int = 0
|
||||
summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
|
||||
continue_reading: bool = True
|
||||
llm: BaseLanguageModel = None
|
||||
model_instance: BaseLLM = None
|
||||
|
||||
def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
|
||||
try:
|
||||
@@ -78,7 +80,7 @@ class WebReaderTool(BaseTool):
|
||||
except Exception as e:
|
||||
return f'Read this website failed, caused by: {str(e)}.'
|
||||
|
||||
if summary and self.llm:
|
||||
if summary and self.model_instance:
|
||||
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
||||
chunk_size=self.summary_chunk_tokens,
|
||||
chunk_overlap=self.summary_chunk_overlap,
|
||||
@@ -95,10 +97,9 @@ class WebReaderTool(BaseTool):
|
||||
if len(docs) > 5:
|
||||
docs = docs[:5]
|
||||
|
||||
chain = load_summarize_chain(self.llm, chain_type="refine", callbacks=self.callbacks)
|
||||
chain = self.get_summary_chain()
|
||||
try:
|
||||
page_contents = chain.run(docs)
|
||||
# todo use cache
|
||||
except Exception as e:
|
||||
return f'Read this website failed, caused by: {str(e)}.'
|
||||
else:
|
||||
@@ -114,6 +115,23 @@ class WebReaderTool(BaseTool):
|
||||
async def _arun(self, url: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_summary_chain(self) -> RefineDocumentsChain:
|
||||
initial_chain = LLMChain(
|
||||
model_instance=self.model_instance,
|
||||
prompt=refine_prompts.PROMPT
|
||||
)
|
||||
refine_chain = LLMChain(
|
||||
model_instance=self.model_instance,
|
||||
prompt=refine_prompts.REFINE_PROMPT
|
||||
)
|
||||
return RefineDocumentsChain(
|
||||
initial_llm_chain=initial_chain,
|
||||
refine_llm_chain=refine_chain,
|
||||
document_variable_name="text",
|
||||
initial_response_name="existing_answer",
|
||||
callbacks=self.callbacks
|
||||
)
|
||||
|
||||
|
||||
def page_result(text: str, cursor: int, max_length: int) -> str:
|
||||
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
|
||||
|
Reference in New Issue
Block a user