refactor: remove unused codes, move core/agent module into dataset retrieval feature (#2614)

This commit is contained in:
takatost
2024-02-28 23:32:47 +08:00
committed by GitHub
parent d44b05a9e5
commit dd961985f0
29 changed files with 41 additions and 2016 deletions

View File

@@ -4,7 +4,7 @@ from langchain.tools import BaseTool
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
from core.tools.tool.tool import Tool
@@ -15,12 +15,12 @@ class DatasetRetrieverTool(Tool):
@staticmethod
def get_dataset_tools(tenant_id: str,
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler
) -> list['DatasetRetrieverTool']:
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler
) -> list['DatasetRetrieverTool']:
"""
get dataset tool
"""
@@ -46,7 +46,7 @@ class DatasetRetrieverTool(Tool):
)
# restore retrieve strategy
retrieve_config.retrieve_strategy = original_retriever_mode
# convert langchain tools to Tools
tools = []
for langchain_tool in langchain_tools:
@@ -60,7 +60,7 @@ class DatasetRetrieverTool(Tool):
llm=langchain_tool.description),
runtime=DatasetRetrieverTool.Runtime()
)
tools.append(tool)
return tools
@@ -68,13 +68,13 @@ class DatasetRetrieverTool(Tool):
def get_runtime_parameters(self) -> list[ToolParameter]:
return [
ToolParameter(name='query',
label=I18nObject(en_US='', zh_Hans=''),
human_description=I18nObject(en_US='', zh_Hans=''),
type=ToolParameter.ToolParameterType.STRING,
form=ToolParameter.ToolParameterForm.LLM,
llm_description='Query for the dataset to be used to retrieve the dataset.',
required=True,
default=''),
label=I18nObject(en_US='', zh_Hans=''),
human_description=I18nObject(en_US='', zh_Hans=''),
type=ToolParameter.ToolParameterType.STRING,
form=ToolParameter.ToolParameterForm.LLM,
llm_description='Query for the dataset to be used to retrieve the dataset.',
required=True,
default=''),
]
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
@@ -84,7 +84,7 @@ class DatasetRetrieverTool(Tool):
query = tool_parameters.get('query', None)
if not query:
return self.create_text_message(text='please input query')
# invoke dataset retriever tool
result = self.langchain_tool._run(query=query)
@@ -94,4 +94,4 @@ class DatasetRetrieverTool(Tool):
"""
validate the credentials for dataset retriever tool
"""
pass
pass

View File

@@ -7,23 +7,14 @@ import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from typing import Any
import requests
from bs4 import BeautifulSoup, CData, Comment, NavigableString
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article
from pydantic import BaseModel, Field
from regex import regex
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.models.document import Document
FULL_TEMPLATE = """
TITLE: {title}
@@ -36,106 +27,6 @@ TEXT:
"""
class WebReaderToolInput(BaseModel):
url: str = Field(..., description="URL of the website to read")
summary: bool = Field(
default=False,
description="When the user's question requires extracting the summarizing content of the webpage, "
"set it to true."
)
cursor: int = Field(
default=0,
description="Start reading from this character."
"Use when the first response was truncated"
"and you want to continue reading the page."
"The value cannot exceed 24000.",
)
class WebReaderTool(BaseTool):
"""Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""
name: str = "web_reader"
args_schema: type[BaseModel] = WebReaderToolInput
description: str = "use this to read a website. " \
"If you can answer the question based on the information provided, " \
"there is no need to use."
page_contents: str = None
url: str = None
max_chunk_length: int = 4000
summary_chunk_tokens: int = 4000
summary_chunk_overlap: int = 0
summary_separators: list[str] = ["\n\n", "", ".", " ", ""]
continue_reading: bool = True
model_config: ModelConfigEntity
model_parameters: dict[str, Any]
def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
try:
if not self.page_contents or self.url != url:
page_contents = get_url(url)
self.page_contents = page_contents
self.url = url
else:
page_contents = self.page_contents
except Exception as e:
return f'Read this website failed, caused by: {str(e)}.'
if summary:
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=self.summary_chunk_tokens,
chunk_overlap=self.summary_chunk_overlap,
separators=self.summary_separators
)
texts = character_splitter.split_text(page_contents)
docs = [Document(page_content=t) for t in texts]
if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
return "No content found."
# only use first 5 docs
if len(docs) > 5:
docs = docs[:5]
chain = self.get_summary_chain()
try:
page_contents = chain.run(docs)
except Exception as e:
return f'Read this website failed, caused by: {str(e)}.'
else:
page_contents = page_result(page_contents, cursor, self.max_chunk_length)
if self.continue_reading and len(page_contents) >= self.max_chunk_length:
page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."
return page_contents
async def _arun(self, url: str) -> str:
raise NotImplementedError
def get_summary_chain(self) -> RefineDocumentsChain:
initial_chain = LLMChain(
model_config=self.model_config,
prompt=refine_prompts.PROMPT,
parameters=self.model_parameters
)
refine_chain = LLMChain(
model_config=self.model_config,
prompt=refine_prompts.REFINE_PROMPT,
parameters=self.model_parameters
)
return RefineDocumentsChain(
initial_llm_chain=initial_chain,
refine_llm_chain=refine_chain,
document_variable_name="text",
initial_response_name="existing_answer",
callbacks=self.callbacks
)
def page_result(text: str, cursor: int, max_length: int) -> str:
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
return text[cursor: cursor + max_length]