feat: return page number of PDF documents upon retrieval (#7749)

This commit is contained in:
Byeongjin Kang
2024-09-05 17:43:26 +09:00
committed by GitHub
parent bd0992275c
commit d489b8b3e0
3 changed files with 10 additions and 2 deletions

View File

@@ -30,7 +30,7 @@ class AbstractVectorFactory(ABC):
class Vector: class Vector:
def __init__(self, dataset: Dataset, attributes: list = None): def __init__(self, dataset: Dataset, attributes: list = None):
if attributes is None: if attributes is None:
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash', 'page']
self._dataset = dataset self._dataset = dataset
self._embeddings = self._get_embeddings() self._embeddings = self._get_embeddings()
self._attributes = attributes self._attributes = attributes
@@ -107,6 +107,7 @@ class Vector:
def add_texts(self, documents: list[Document], **kwargs): def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get('duplicate_check', False): if kwargs.get('duplicate_check', False):
documents = self._filter_duplicate_texts(documents) documents = self._filter_duplicate_texts(documents)
embeddings = self._embeddings.embed_documents([document.page_content for document in documents]) embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
self._vector_processor.create( self._vector_processor.create(
texts=documents, texts=documents,

View File

@@ -173,9 +173,13 @@ class KnowledgeRetrievalNode(BaseNode):
context_list = [] context_list = []
if all_documents: if all_documents:
document_score_list = {} document_score_list = {}
page_number_list = {}
for item in all_documents: for item in all_documents:
if item.metadata.get('score'): if item.metadata.get('score'):
document_score_list[item.metadata['doc_id']] = item.metadata['score'] document_score_list[item.metadata['doc_id']] = item.metadata['score']
# both 'page' and 'score' are metadata fields
if item.metadata.get('page'):
page_number_list[item.metadata['doc_id']] = item.metadata['page']
index_node_ids = [document.metadata['doc_id'] for document in all_documents] index_node_ids = [document.metadata['doc_id'] for document in all_documents]
segments = DocumentSegment.query.filter( segments = DocumentSegment.query.filter(
@@ -199,9 +203,9 @@ class KnowledgeRetrievalNode(BaseNode):
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,
).first() ).first()
resource_number = 1 resource_number = 1
if dataset and document: if dataset and document:
source = { source = {
'metadata': { 'metadata': {
'_source': 'knowledge', '_source': 'knowledge',
@@ -211,6 +215,7 @@ class KnowledgeRetrievalNode(BaseNode):
'document_id': document.id, 'document_id': document.id,
'document_name': document.name, 'document_name': document.name,
'document_data_source_type': document.data_source_type, 'document_data_source_type': document.data_source_type,
'page': page_number_list.get(segment.index_node_id, None),
'segment_id': segment.id, 'segment_id': segment.id,
'retriever_from': 'workflow', 'retriever_from': 'workflow',
'score': document_score_list.get(segment.index_node_id, None), 'score': document_score_list.get(segment.index_node_id, None),

View File

@@ -402,6 +402,7 @@ class LLMNode(BaseNode):
if ('metadata' in context_dict and '_source' in context_dict['metadata'] if ('metadata' in context_dict and '_source' in context_dict['metadata']
and context_dict['metadata']['_source'] == 'knowledge'): and context_dict['metadata']['_source'] == 'knowledge'):
metadata = context_dict.get('metadata', {}) metadata = context_dict.get('metadata', {})
source = { source = {
'position': metadata.get('position'), 'position': metadata.get('position'),
'dataset_id': metadata.get('dataset_id'), 'dataset_id': metadata.get('dataset_id'),
@@ -417,6 +418,7 @@ class LLMNode(BaseNode):
'segment_position': metadata.get('segment_position'), 'segment_position': metadata.get('segment_position'),
'index_node_hash': metadata.get('segment_index_node_hash'), 'index_node_hash': metadata.get('segment_index_node_hash'),
'content': context_dict.get('content'), 'content': context_dict.get('content'),
'page': metadata.get('page'),
} }
return source return source