Feat: Q&A format segmentation support (#668)
Co-authored-by: jyong <718720800@qq.com>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
@@ -39,7 +39,7 @@ class ExcelLoader(BaseLoader):
                    row_dict = dict(zip(keys, list(map(str, row))))
                    row_dict = {k: v for k, v in row_dict.items() if v}
                    item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())
-                   document = Document(page_content=item)
+                   document = Document(page_content=item, metadata={'source': self._file_path})
                    data.append(document)

        return data
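For orientation, a minimal sketch of what one spreadsheet row becomes after this change (the header, row values and file path below are invented; only the formatting logic mirrors the loader):

# hypothetical header row and data row from a worksheet
keys = ['name', 'price', 'note']
row = ('Widget', 9.9, '')

row_dict = dict(zip(keys, list(map(str, row))))
row_dict = {k: v for k, v in row_dict.items() if v}        # drops the empty 'note' cell
item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())

print(item)                              # "name:Widget\nprice:9.9\n" becomes the page_content
print({'source': '/tmp/example.xlsx'})   # the metadata now records which file the row came from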
@@ -68,7 +68,7 @@ class DatesetDocumentStore:
        self, docs: Sequence[Document], allow_update: bool = True
    ) -> None:
        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
-           DocumentSegment.document == self._document_id
+           DocumentSegment.document_id == self._document_id
        ).scalar()

        if max_position is None:

@@ -105,9 +105,14 @@ class DatesetDocumentStore:
                    tokens=tokens,
                    created_by=self._user_id,
                )
+               if 'answer' in doc.metadata and doc.metadata['answer']:
+                   segment_document.answer = doc.metadata.pop('answer', '')
+
                db.session.add(segment_document)
            else:
                segment_document.content = doc.page_content
+               if 'answer' in doc.metadata and doc.metadata['answer']:
+                   segment_document.answer = doc.metadata.pop('answer', '')
                segment_document.index_node_hash = doc.metadata['doc_hash']
                segment_document.word_count = len(doc.page_content)
                segment_document.tokens = tokens
@@ -2,7 +2,7 @@ import logging

from langchain import PromptTemplate
from langchain.chat_models.base import BaseChatModel
-from langchain.schema import HumanMessage, OutputParserException, BaseMessage
+from langchain.schema import HumanMessage, OutputParserException, BaseMessage, SystemMessage

from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder

@@ -12,8 +12,8 @@ from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorO

from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
-from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT
+from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
+    GENERATOR_QA_PROMPT

# gpt-3.5-turbo does not work well here
generate_base_model = 'text-davinci-003'

@@ -31,7 +31,8 @@ class LLMGenerator:
        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
-            max_tokens=50
+            max_tokens=50,
+            timeout=600
        )

        if isinstance(llm, BaseChatModel):

@@ -185,3 +186,27 @@ class LLMGenerator:
        }

        return rule_config

+    @classmethod
+    async def generate_qa_document(cls, llm: StreamableOpenAI, query):
+        prompt = GENERATOR_QA_PROMPT
+
+        if isinstance(llm, BaseChatModel):
+            prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
+
+        response = llm.generate([prompt])
+        answer = response.generations[0][0].text
+        return answer.strip()
+
+    @classmethod
+    def generate_qa_document_sync(cls, llm: StreamableOpenAI, query):
+        prompt = GENERATOR_QA_PROMPT
+
+        if isinstance(llm, BaseChatModel):
+            prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
+
+        response = llm.generate([prompt])
+        answer = response.generations[0][0].text
+        return answer.strip()
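To show how the new Q&A generator is meant to be driven, here is a minimal sketch. It assumes a running app context with provider credentials configured; the tenant id is a placeholder, and only LLMBuilder.to_llm and LLMGenerator.generate_qa_document_sync come from the code above:

from core.generator.llm_generator import LLMGenerator
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI

# hypothetical tenant id; in the application this comes from the current dataset/user
tenant_id = '00000000-0000-0000-0000-000000000000'

llm: StreamableOpenAI = LLMBuilder.to_llm(
    tenant_id=tenant_id,
    model_name='gpt-3.5-turbo',
    max_tokens=2000
)

# returns raw text in the "Q1:\nA1:\nQ2:\nA2:..." format requested by GENERATOR_QA_PROMPT
qa_text = LLMGenerator.generate_qa_document_sync(llm, 'A segment of source text to turn into Q&A pairs.')
print(qa_text)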
@@ -205,6 +205,16 @@ class KeywordTableIndex(BaseIndex):
            document_segment.keywords = keywords
            db.session.commit()

+    def create_segment_keywords(self, node_id: str, keywords: List[str]):
+        keyword_table = self._get_dataset_keyword_table()
+        self._update_segment_keywords(node_id, keywords)
+        keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
+        self._save_dataset_keyword_table(keyword_table)
+
+    def update_segment_keywords_index(self, node_id: str, keywords: List[str]):
+        keyword_table = self._get_dataset_keyword_table()
+        keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
+        self._save_dataset_keyword_table(keyword_table)
+

class KeywordTableRetriever(BaseRetriever, BaseModel):
    index: KeywordTableIndex
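A rough usage sketch for the two new keyword helpers; 'dataset' is assumed to be a models.dataset.Dataset loaded elsewhere, and the node id and keywords are invented:

from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig

index = KeywordTableIndex(
    dataset=dataset,
    config=KeywordTableConfig(max_keywords_per_chunk=10)
)

# register keywords for a newly created segment (updates the segment row and the dataset keyword table)
index.create_segment_keywords('node-id-123', ['pricing', 'refund policy'])

# re-index an existing segment after its keywords were edited
index.update_segment_keywords_index('node-id-123', ['pricing', 'billing'])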
api/core/index/vector_index/test-embedding.py (new file, 123 lines)
@@ -0,0 +1,123 @@
import numpy as np
import sklearn.decomposition
import pickle
import time


# Apply 'Algorithm 1' to the ada-002 embeddings to make them isotropic, taken from the paper:
# ALL-BUT-THE-TOP: SIMPLE AND EFFECTIVE POST-PROCESSING FOR WORD REPRESENTATIONS
# Jiaqi Mu, Pramod Viswanath

# This uses Principal Component Analysis (PCA) to 'evenly distribute' the embedding vectors (make them isotropic)
# For more information on PCA, see https://jamesmccaffrey.wordpress.com/2021/07/16/computing-pca-using-numpy-without-scikit/


# get the file pointer of the pickle containing the embeddings
fp = open('/path/to/your/data/Embedding-Latest.pkl', 'rb')


# the embedding data here is a dict consisting of key / value pairs
# the key is the hash of the message (SHA3-256), the value is the embedding from ada-002 (array of dimension 1536)
# the hash can be used to look up the original text in a database
E = pickle.load(fp)  # load the data into memory

# separate the keys (hashes) and values (embeddings) into separate vectors
K = list(E.keys())  # vector of all the hash values
X = np.array(list(E.values()))  # vector of all the embeddings, converted to numpy arrays


# list the total number of embeddings
# this can be truncated if there are too many embeddings to do PCA on
print(f"Total number of embeddings: {len(X)}")

# get dimension of embeddings, used later
Dim = len(X[0])

# print out the first few embeddings
print("First two embeddings are: ")
print(X[0])
print(f"First embedding length: {len(X[0])}")
print(X[1])
print(f"Second embedding length: {len(X[1])}")


# compute the mean of all the embeddings, and print the result
mu = np.mean(X, axis=0)  # same as mu in paper
print(f"Mean embedding vector: {mu}")
print(f"Mean embedding vector length: {len(mu)}")


# subtract the mean vector from each embedding vector ... vectorized in numpy
X_tilde = X - mu  # same as v_tilde(w) in paper


# do the heavy lifting of extracting the principal components
# note that this is a function of the embeddings you currently have here, and this set may grow over time
# therefore the PCA basis vectors may change over time, and your final isotropic embeddings may drift over time
# but the drift should stabilize after you have extracted enough embedding data to characterize the nature of the embedding engine
print(f"Performing PCA on the normalized embeddings ...")
pca = sklearn.decomposition.PCA()  # new object
TICK = time.time()  # start timer
pca.fit(X_tilde)  # do the heavy lifting!
TOCK = time.time()  # end timer
DELTA = TOCK - TICK

print(f"PCA finished in {DELTA} seconds ...")

# dimensional reduction stage (the only hyperparameter)
# pick max dimension of PCA components to express embeddings
# in general this is some integer less than or equal to the dimension of your embeddings
# it could be set as a high percentile, say 95th percentile of pca.explained_variance_ratio_
# but just hardcoding a constant here
D = 15  # hyperparameter on dimension (out of 1536 for ada-002), paper recommends D = Dim/100


# form the set of v_prime(w), which is the final embedding
# this could be vectorized in numpy to speed it up, but coding it directly here in a double for-loop to avoid errors and to be transparent
E_prime = dict()  # output dict of the new embeddings
N = len(X_tilde)
N10 = round(N/10)
U = pca.components_  # set of PCA basis vectors, sorted by most significant to least significant
print(f"Shape of full set of PCA components {U.shape}")
U = U[0:D,:]  # take the top D dimensions (or take them all if D is the size of the embedding vector)
print(f"Shape of downselected PCA components {U.shape}")
for ii in range(N):
    v_tilde = X_tilde[ii]
    v = X[ii]
    v_projection = np.zeros(Dim)  # start to build the projection
    # project the original embedding onto the PCA basis vectors, use only first D dimensions
    for jj in range(D):
        u_jj = U[jj,:]  # vector
        v_jj = np.dot(u_jj,v)  # scalar
        v_projection += v_jj*u_jj  # vector
    v_prime = v_tilde - v_projection  # final embedding vector
    v_prime = v_prime/np.linalg.norm(v_prime)  # create unit vector
    E_prime[K[ii]] = v_prime

    if (ii%N10 == 0) or (ii == N-1):
        print(f"Finished with {ii+1} embeddings out of {N} ({round(100*ii/N)}% done)")


# save as new pickle
print("Saving new pickle ...")
embeddingName = '/path/to/your/data/Embedding-Latest-Isotropic.pkl'
with open(embeddingName, 'wb') as f:  # Python 3: open(..., 'wb')
    pickle.dump([E_prime, mu, U], f)
print(embeddingName)

print("Done!")

# When working with live data with a new embedding from ada-002, be sure to transform it first with this function before comparing it
#
def projectEmbedding(v, mu, U):
    v = np.array(v)
    v_tilde = v - mu
    v_projection = np.zeros(len(v))  # start to build the projection
    # project the original embedding onto the PCA basis vectors, use only first D dimensions
    for u in U:
        v_jj = np.dot(u, v)  # scalar
        v_projection += v_jj*u  # vector
    v_prime = v_tilde - v_projection  # final embedding vector
    v_prime = v_prime/np.linalg.norm(v_prime)  # create unit vector
    return v_prime
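As a usage sketch for the script above (the pickle path is a placeholder, and the random vector merely stands in for a fresh ada-002 query embedding), a new vector would be projected with projectEmbedding before any cosine-similarity comparison against the stored isotropic embeddings; run in the same script so projectEmbedding is in scope:

import pickle
import numpy as np

# load the isotropic embeddings plus the PCA artifacts saved above (placeholder path)
with open('/path/to/your/data/Embedding-Latest-Isotropic.pkl', 'rb') as f:
    E_prime, mu, U = pickle.load(f)

# 'fresh_embedding' stands in for a brand-new ada-002 vector obtained at query time
fresh_embedding = np.random.default_rng(0).normal(size=len(mu)).tolist()

v_prime = projectEmbedding(fresh_embedding, mu, U)

# both sides are unit vectors, so cosine similarity reduces to a dot product
best_key = max(E_prime, key=lambda k: float(np.dot(E_prime[k], v_prime)))
print(f"Closest stored item (by hash key): {best_key}")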
@@ -1,13 +1,20 @@
import asyncio
import concurrent
import datetime
import json
import logging
import re
import threading
import time
import uuid
from multiprocessing import Process
from typing import Optional, List, cast

-from flask import current_app
import openai
from billiard.pool import Pool
+from flask import current_app, Flask
from flask_login import current_user
from gevent.threadpool import ThreadPoolExecutor
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

@@ -16,11 +23,13 @@ from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.embedding.cached_embedding import CacheEmbedding
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.error import ProviderTokenNotInitError
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db

@@ -70,7 +79,13 @@ class IndexingRunner:
                dataset_document=dataset_document,
                processing_rule=processing_rule
            )

            # new_documents = []
            # for document in documents:
            #     response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content)
            #     document_qa_list = self.format_split_text(response)
            #     for result in document_qa_list:
            #         document = Document(page_content=result['question'], metadata={'source': result['answer']})
            #         new_documents.append(document)
            # build index
            self._build_index(
                dataset=dataset,

@@ -91,6 +106,22 @@ class IndexingRunner:
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

+    def format_split_text(self, text):
+        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
+        matches = re.findall(regex, text, re.MULTILINE)
+
+        result = []
+        for match in matches:
+            q = match[0]
+            a = match[1]
+            if q and a:
+                result.append({
+                    "question": q,
+                    "answer": re.sub(r"\n\s*", "\n", a.strip())
+                })
+
+        return result
+
    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
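The new format_split_text helper above is what turns the model's raw "Q1:/A1:" output into structured pairs. A small self-contained demonstration of that regex (the sample text is invented):

import re

# invented sample output in the format requested by GENERATOR_QA_PROMPT
sample = (
    "Q1: What does the Q&A segmentation mode do?\n"
    "A1: It converts each document chunk into question/answer pairs before indexing.\n"
    "Q2: Which model generates the pairs?\n"
    "A2: gpt-3.5-turbo, driven by GENERATOR_QA_PROMPT.\n"
)

regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
pairs = [
    {"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())}
    for q, a in re.findall(regex, sample, re.MULTILINE)
    if q and a
]
for pair in pairs:
    print(pair)
# {'question': 'What does the Q&A segmentation mode do?', 'answer': 'It converts each document chunk into question/answer pairs before indexing.'}
# {'question': 'Which model generates the pairs?', 'answer': 'gpt-3.5-turbo, driven by GENERATOR_QA_PROMPT.'}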
@@ -205,7 +236,8 @@ class IndexingRunner:
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

-    def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
+    def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict,
+                               doc_form: str = None) -> dict:
        """
        Estimate the indexing for the document.
        """

@@ -225,7 +257,7 @@ class IndexingRunner:
            splitter = self._get_splitter(processing_rule)

            # split to documents
-            documents = self._split_to_documents(
+            documents = self._split_to_documents_for_estimate(
                text_docs=text_docs,
                splitter=splitter,
                processing_rule=processing_rule

@@ -237,7 +269,25 @@ class IndexingRunner:

                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name,
                                                         self.filter_string(document.page_content))

+        if doc_form and doc_form == 'qa_model':
+            if len(preview_texts) > 0:
+                # qa model document
+                llm: StreamableOpenAI = LLMBuilder.to_llm(
+                    tenant_id=current_user.current_tenant_id,
+                    model_name='gpt-3.5-turbo',
+                    max_tokens=2000
+                )
+                response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
+                document_qa_list = self.format_split_text(response)
+                return {
+                    "total_segments": total_segments * 20,
+                    "tokens": total_segments * 2000,
+                    "total_price": '{:f}'.format(
+                        TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
+                    "currency": TokenCalculator.get_currency(self.embedding_model_name),
+                    "qa_preview": document_qa_list,
+                    "preview": preview_texts
+                }
        return {
            "total_segments": total_segments,
            "tokens": tokens,
@@ -246,7 +296,7 @@ class IndexingRunner:
            "preview": preview_texts
        }

-    def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict) -> dict:
+    def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
        """
        Estimate the indexing for the document.
        """

@@ -285,7 +335,7 @@ class IndexingRunner:
            splitter = self._get_splitter(processing_rule)

            # split to documents
-            documents = self._split_to_documents(
+            documents = self._split_to_documents_for_estimate(
                text_docs=documents,
                splitter=splitter,
                processing_rule=processing_rule

@@ -296,7 +346,25 @@ class IndexingRunner:
                preview_texts.append(document.page_content)

                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)

+        if doc_form and doc_form == 'qa_model':
+            if len(preview_texts) > 0:
+                # qa model document
+                llm: StreamableOpenAI = LLMBuilder.to_llm(
+                    tenant_id=current_user.current_tenant_id,
+                    model_name='gpt-3.5-turbo',
+                    max_tokens=2000
+                )
+                response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
+                document_qa_list = self.format_split_text(response)
+                return {
+                    "total_segments": total_segments * 20,
+                    "tokens": total_segments * 2000,
+                    "total_price": '{:f}'.format(
+                        TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
+                    "currency": TokenCalculator.get_currency(self.embedding_model_name),
+                    "qa_preview": document_qa_list,
+                    "preview": preview_texts
+                }
        return {
            "total_segments": total_segments,
            "tokens": tokens,
@@ -391,7 +459,9 @@ class IndexingRunner:
            documents = self._split_to_documents(
                text_docs=text_docs,
                splitter=splitter,
-                processing_rule=processing_rule
+                processing_rule=processing_rule,
+                tenant_id=dataset.tenant_id,
+                document_form=dataset_document.doc_form
            )

            # save node to document segment

@@ -428,7 +498,64 @@ class IndexingRunner:
        return documents

    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
-                            processing_rule: DatasetProcessRule) -> List[Document]:
+                            processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            llm: StreamableOpenAI = LLMBuilder.to_llm(
                tenant_id=tenant_id,
                model_name='gpt-3.5-turbo',
                max_tokens=2000
            )
            self.format_document(llm, documents, split_documents, document_form)
            all_documents.extend(split_documents)

        return all_documents

    def format_document(self, llm: StreamableOpenAI, documents: List[Document], split_documents: List, document_form: str):
        for document_node in documents:
            format_documents = []
            if document_node.page_content is None or not document_node.page_content.strip():
                return format_documents
            if document_form == 'text_model':
                # text model document
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document_node.page_content)

                document_node.metadata['doc_id'] = doc_id
                document_node.metadata['doc_hash'] = hash

                format_documents.append(document_node)
            elif document_form == 'qa_model':
                try:
                    # qa model document
                    response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
                    document_qa_list = self.format_split_text(response)
                    qa_documents = []
                    for result in document_qa_list:
                        qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
                        doc_id = str(uuid.uuid4())
                        hash = helper.generate_text_hash(result['question'])
                        qa_document.metadata['answer'] = result['answer']
                        qa_document.metadata['doc_id'] = doc_id
                        qa_document.metadata['doc_hash'] = hash
                        qa_documents.append(qa_document)
                    format_documents.extend(qa_documents)
                except Exception:
                    continue
            split_documents.extend(format_documents)
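To make the shape of the qa_model output concrete, here is a minimal sketch of the Document objects that format_document ends up appending for one chunk. The values are invented, and hashlib stands in for the project's helper.generate_text_hash:

import uuid
import hashlib

from langchain.schema import Document

# pretend format_split_text returned one parsed pair for this chunk
document_qa_list = [{
    'question': 'What does Q&A segmentation store per segment?',
    'answer': 'The question as page_content and the answer in metadata.',
}]

source_metadata = {'source': '/tmp/example.xlsx'}  # copied from the original chunk's metadata

split_documents = []
for result in document_qa_list:
    qa_document = Document(page_content=result['question'], metadata=source_metadata.copy())
    qa_document.metadata['answer'] = result['answer']
    qa_document.metadata['doc_id'] = str(uuid.uuid4())
    # stand-in for helper.generate_text_hash(result['question'])
    qa_document.metadata['doc_hash'] = hashlib.sha3_256(result['question'].encode('utf-8')).hexdigest()
    split_documents.append(qa_document)

print(split_documents[0].page_content)   # the question becomes the indexed text
print(split_documents[0].metadata)       # answer, doc_id and doc_hash ride along in metadata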
    def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
                                         processing_rule: DatasetProcessRule) -> List[Document]:
        """
        Split the text documents into nodes.
        """

@@ -445,7 +572,6 @@ class IndexingRunner:
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue

                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document.page_content)

@@ -487,6 +613,23 @@ class IndexingRunner:

        return text

+    def format_split_text(self, text):
+        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"  # regex that matches the Q and A lines
+        matches = re.findall(regex, text, re.MULTILINE)  # collect all matches
+
+        result = []  # store the final results
+        for match in matches:
+            q = match[0]
+            a = match[1]
+            if q and a:
+                # only add the pair to the results if both Q and A are present
+                result.append({
+                    "question": q,
+                    "answer": re.sub(r"\n\s*", "\n", a.strip())
+                })
+
+        return result
+
    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
        """
        Build the index for the document.
@@ -43,6 +43,16 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
    "[\"question1\",\"question2\",\"question3\"]\n"
)

+GENERATOR_QA_PROMPT = (
+    "Please respond according to the language of the user's input text. If the text is in language [A], you must also reply in language [A].\n"
+    'Step 1: Understand and summarize the main content of this text.\n'
+    'Step 2: What key information or concepts are mentioned in this text?\n'
+    'Step 3: Decompose or combine multiple pieces of information and concepts.\n'
+    'Step 4: Generate 20 questions and answers based on these key information and concepts.'
+    'The questions should be clear and detailed, and the answers should be detailed and complete.\n'
+    "Answer in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
+)
+
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
the model prompt that best suits the input.
You will be provided with the prompt, variables, and an opening statement.
api/core/tool/dataset_index_tool.py (new file, 102 lines)
@@ -0,0 +1,102 @@
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset, DocumentSegment


class DatasetTool(BaseTool):
    """Tool for querying a Dataset."""

    dataset: Dataset
    k: int = 2

    def _run(self, tool_input: str) -> str:
        if self.dataset.indexing_technique == "economy":
            # use keyword table query
            kw_table_index = KeywordTableIndex(
                dataset=self.dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
            return str("\n".join([document.page_content for document in documents]))
        else:
            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=self.dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
                model_name='text-embedding-ada-002'
            )

            embeddings = CacheEmbedding(OpenAIEmbeddings(
                **model_credentials
            ))

            vector_index = VectorIndex(
                dataset=self.dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            documents = vector_index.search(
                tool_input,
                search_type='similarity',
                search_kwargs={
                    'k': self.k
                }
            )

            hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
            hit_callback.on_tool_end(documents)
            document_context_list = []
            index_node_ids = [document.metadata['doc_id'] for document in documents]
            segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
                                                    DocumentSegment.status == 'completed',
                                                    DocumentSegment.enabled == True,
                                                    DocumentSegment.index_node_id.in_(index_node_ids)
                                                    ).all()

            if segments:
                for segment in segments:
                    if segment.answer:
                        document_context_list.append(segment.answer)
                    else:
                        document_context_list.append(segment.content)

            return str("\n".join(document_context_list))

    async def _arun(self, tool_input: str) -> str:
        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=self.dataset.tenant_id,
            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
            model_name='text-embedding-ada-002'
        )

        embeddings = CacheEmbedding(OpenAIEmbeddings(
            **model_credentials
        ))

        vector_index = VectorIndex(
            dataset=self.dataset,
            config=current_app.config,
            embeddings=embeddings
        )

        documents = await vector_index.asearch(
            tool_input,
            search_type='similarity',
            search_kwargs={
                'k': 10
            }
        )

        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
        hit_callback.on_tool_end(documents)
        return str("\n".join([document.page_content for document in documents]))
@@ -12,7 +12,7 @@ from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
-from models.dataset import Dataset
+from models.dataset import Dataset, DocumentSegment


class DatasetRetrieverToolInput(BaseModel):

@@ -69,6 +69,7 @@ class DatasetRetrieverTool(BaseTool):
            )

            documents = kw_table_index.search(query, search_kwargs={'k': self.k})
+            return str("\n".join([document.page_content for document in documents]))
        else:
            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=dataset.tenant_id,

@@ -99,8 +100,22 @@ class DatasetRetrieverTool(BaseTool):

            hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
            hit_callback.on_tool_end(documents)
+            document_context_list = []
+            index_node_ids = [document.metadata['doc_id'] for document in documents]
+            segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
+                                                    DocumentSegment.status == 'completed',
+                                                    DocumentSegment.enabled == True,
+                                                    DocumentSegment.index_node_id.in_(index_node_ids)
+                                                    ).all()

-            return str("\n".join([document.page_content for document in documents]))
+            if segments:
+                for segment in segments:
+                    if segment.answer:
+                        document_context_list.append(f'question:{segment.content} \nanswer:{segment.answer}')
+                    else:
+                        document_context_list.append(segment.content)
+
+            return str("\n".join(document_context_list))

    async def _arun(self, tool_input: str) -> str:
        raise NotImplementedError()