refactor: remove unused codes, move core/agent module into dataset retrieval feature (#2614)
This commit is contained in:
@@ -4,7 +4,7 @@ from langchain.tools import BaseTool
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
|
||||
from core.features.dataset_retrieval import DatasetRetrievalFeature
|
||||
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
|
||||
from core.tools.tool.tool import Tool
|
||||
@@ -15,12 +15,12 @@ class DatasetRetrieverTool(Tool):
|
||||
|
||||
@staticmethod
|
||||
def get_dataset_tools(tenant_id: str,
|
||||
dataset_ids: list[str],
|
||||
retrieve_config: DatasetRetrieveConfigEntity,
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler
|
||||
) -> list['DatasetRetrieverTool']:
|
||||
dataset_ids: list[str],
|
||||
retrieve_config: DatasetRetrieveConfigEntity,
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler
|
||||
) -> list['DatasetRetrieverTool']:
|
||||
"""
|
||||
get dataset tool
|
||||
"""
|
||||
@@ -46,7 +46,7 @@ class DatasetRetrieverTool(Tool):
|
||||
)
|
||||
# restore retrieve strategy
|
||||
retrieve_config.retrieve_strategy = original_retriever_mode
|
||||
|
||||
|
||||
# convert langchain tools to Tools
|
||||
tools = []
|
||||
for langchain_tool in langchain_tools:
|
||||
@@ -60,7 +60,7 @@ class DatasetRetrieverTool(Tool):
|
||||
llm=langchain_tool.description),
|
||||
runtime=DatasetRetrieverTool.Runtime()
|
||||
)
|
||||
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
return tools
|
||||
@@ -68,13 +68,13 @@ class DatasetRetrieverTool(Tool):
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
return [
|
||||
ToolParameter(name='query',
|
||||
label=I18nObject(en_US='', zh_Hans=''),
|
||||
human_description=I18nObject(en_US='', zh_Hans=''),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description='Query for the dataset to be used to retrieve the dataset.',
|
||||
required=True,
|
||||
default=''),
|
||||
label=I18nObject(en_US='', zh_Hans=''),
|
||||
human_description=I18nObject(en_US='', zh_Hans=''),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description='Query for the dataset to be used to retrieve the dataset.',
|
||||
required=True,
|
||||
default=''),
|
||||
]
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
@@ -84,7 +84,7 @@ class DatasetRetrieverTool(Tool):
|
||||
query = tool_parameters.get('query', None)
|
||||
if not query:
|
||||
return self.create_text_message(text='please input query')
|
||||
|
||||
|
||||
# invoke dataset retriever tool
|
||||
result = self.langchain_tool._run(query=query)
|
||||
|
||||
@@ -94,4 +94,4 @@ class DatasetRetrieverTool(Tool):
|
||||
"""
|
||||
validate the credentials for dataset retriever tool
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
@@ -7,23 +7,14 @@ import subprocess
|
||||
import tempfile
|
||||
import unicodedata
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup, CData, Comment, NavigableString
|
||||
from langchain.chains import RefineDocumentsChain
|
||||
from langchain.chains.summarize import refine_prompts
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain.tools.base import BaseTool
|
||||
from newspaper import Article
|
||||
from pydantic import BaseModel, Field
|
||||
from regex import regex
|
||||
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.rag.extractor import extract_processor
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
FULL_TEMPLATE = """
|
||||
TITLE: {title}
|
||||
@@ -36,106 +27,6 @@ TEXT:
|
||||
"""
|
||||
|
||||
|
||||
class WebReaderToolInput(BaseModel):
|
||||
url: str = Field(..., description="URL of the website to read")
|
||||
summary: bool = Field(
|
||||
default=False,
|
||||
description="When the user's question requires extracting the summarizing content of the webpage, "
|
||||
"set it to true."
|
||||
)
|
||||
cursor: int = Field(
|
||||
default=0,
|
||||
description="Start reading from this character."
|
||||
"Use when the first response was truncated"
|
||||
"and you want to continue reading the page."
|
||||
"The value cannot exceed 24000.",
|
||||
)
|
||||
|
||||
|
||||
class WebReaderTool(BaseTool):
|
||||
"""Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""
|
||||
|
||||
name: str = "web_reader"
|
||||
args_schema: type[BaseModel] = WebReaderToolInput
|
||||
description: str = "use this to read a website. " \
|
||||
"If you can answer the question based on the information provided, " \
|
||||
"there is no need to use."
|
||||
page_contents: str = None
|
||||
url: str = None
|
||||
max_chunk_length: int = 4000
|
||||
summary_chunk_tokens: int = 4000
|
||||
summary_chunk_overlap: int = 0
|
||||
summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
|
||||
continue_reading: bool = True
|
||||
model_config: ModelConfigEntity
|
||||
model_parameters: dict[str, Any]
|
||||
|
||||
def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
|
||||
try:
|
||||
if not self.page_contents or self.url != url:
|
||||
page_contents = get_url(url)
|
||||
self.page_contents = page_contents
|
||||
self.url = url
|
||||
else:
|
||||
page_contents = self.page_contents
|
||||
except Exception as e:
|
||||
return f'Read this website failed, caused by: {str(e)}.'
|
||||
|
||||
if summary:
|
||||
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
||||
chunk_size=self.summary_chunk_tokens,
|
||||
chunk_overlap=self.summary_chunk_overlap,
|
||||
separators=self.summary_separators
|
||||
)
|
||||
|
||||
texts = character_splitter.split_text(page_contents)
|
||||
docs = [Document(page_content=t) for t in texts]
|
||||
|
||||
if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
|
||||
return "No content found."
|
||||
|
||||
# only use first 5 docs
|
||||
if len(docs) > 5:
|
||||
docs = docs[:5]
|
||||
|
||||
chain = self.get_summary_chain()
|
||||
try:
|
||||
page_contents = chain.run(docs)
|
||||
except Exception as e:
|
||||
return f'Read this website failed, caused by: {str(e)}.'
|
||||
else:
|
||||
page_contents = page_result(page_contents, cursor, self.max_chunk_length)
|
||||
|
||||
if self.continue_reading and len(page_contents) >= self.max_chunk_length:
|
||||
page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
|
||||
f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
|
||||
f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."
|
||||
|
||||
return page_contents
|
||||
|
||||
async def _arun(self, url: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_summary_chain(self) -> RefineDocumentsChain:
|
||||
initial_chain = LLMChain(
|
||||
model_config=self.model_config,
|
||||
prompt=refine_prompts.PROMPT,
|
||||
parameters=self.model_parameters
|
||||
)
|
||||
refine_chain = LLMChain(
|
||||
model_config=self.model_config,
|
||||
prompt=refine_prompts.REFINE_PROMPT,
|
||||
parameters=self.model_parameters
|
||||
)
|
||||
return RefineDocumentsChain(
|
||||
initial_llm_chain=initial_chain,
|
||||
refine_llm_chain=refine_chain,
|
||||
document_variable_name="text",
|
||||
initial_response_name="existing_answer",
|
||||
callbacks=self.callbacks
|
||||
)
|
||||
|
||||
|
||||
def page_result(text: str, cursor: int, max_length: int) -> str:
|
||||
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
|
||||
return text[cursor: cursor + max_length]
|
||||
|
Reference in New Issue
Block a user