Feat:dataset retiever resource (#1123)

Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
Jyong
2023-09-10 15:17:43 +08:00
committed by GitHub
parent e161c511af
commit 642842d61b
32 changed files with 442 additions and 33 deletions

View File

@@ -1,3 +1,4 @@
import json
from typing import Type
from flask import current_app
@@ -5,13 +6,14 @@ from langchain.tools import BaseTool
from pydantic import Field, BaseModel
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.conversation_message_task import ConversationMessageTask
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Dataset, DocumentSegment, Document
class DatasetRetrieverToolInput(BaseModel):
@@ -27,6 +29,10 @@ class DatasetRetrieverTool(BaseTool):
tenant_id: str
dataset_id: str
k: int = 3
conversation_message_task: ConversationMessageTask
return_resource: str
retriever_from: str
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
@@ -86,7 +92,7 @@ class DatasetRetrieverTool(BaseTool):
if self.k > 0:
documents = vector_index.search(
query,
search_type='similarity',
search_type='similarity_score_threshold',
search_kwargs={
'k': self.k
}
@@ -94,8 +100,12 @@ class DatasetRetrieverTool(BaseTool):
else:
documents = []
hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task)
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
document_score_list[item.metadata['doc_id']] = item.metadata['score']
document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in documents]
segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id,
@@ -112,9 +122,43 @@ class DatasetRetrieverTool(BaseTool):
float('inf')))
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f'question:{segment.content} \nanswer:{segment.answer}')
document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
else:
document_context_list.append(segment.content)
if self.return_resource:
context_list = []
resource_number = 1
for segment in sorted_segments:
context = {}
document = Document.query.filter(Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
if dataset and document:
source = {
'position': resource_number,
'dataset_id': dataset.id,
'dataset_name': dataset.name,
'document_id': document.id,
'document_name': document.name,
'data_source_type': document.data_source_type,
'segment_id': segment.id,
'retriever_from': self.retriever_from
}
if dataset.indexing_technique != "economy":
source['score'] = document_score_list.get(segment.index_node_id)
if self.retriever_from == 'dev':
source['hit_count'] = segment.hit_count
source['word_count'] = segment.word_count
source['segment_position'] = segment.position
source['index_node_hash'] = segment.index_node_hash
if segment.answer:
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
else:
source['content'] = segment.content
context_list.append(source)
resource_number += 1
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))