Feat/add retriever rerank (#1560)

Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
Jyong
2023-11-17 22:13:37 +08:00
committed by GitHub
parent a4f37220a0
commit 4588831bff
44 changed files with 1899 additions and 164 deletions

View File

@@ -1,4 +1,6 @@
import json
import logging
import threading
import time
from typing import List
@@ -9,16 +11,26 @@ from langchain.schema import Document
from sklearn.manifold import TSNE
from core.embedding.cached_embedding import CacheEmbedding
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.account import Account
from models.dataset import Dataset, DocumentSegment, DatasetQuery
from services.retrieval_service import RetrievalService
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
class HitTestingService:
@classmethod
def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return {
"query": {
@@ -28,31 +40,68 @@ class HitTestingService:
"records": []
}
start = time.perf_counter()
# get retrieval model , if the model is not setting , using default
if not retrieval_model:
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
# get embedding model
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
embeddings = CacheEmbedding(embedding_model)
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
all_documents = []
threads = []
# retrieval_model source with semantic
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'query': query,
'top_k': retrieval_model['top_k'],
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
'all_documents': all_documents,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings
})
threads.append(embedding_thread)
embedding_thread.start()
# retrieval source with full text
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'query': query,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
'top_k': retrieval_model['top_k'],
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
'all_documents': all_documents
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
if retrieval_model['search_method'] == 'hybrid_search':
hybrid_rerank = ModelFactory.get_reranking_model(
tenant_id=dataset.tenant_id,
model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
model_name=retrieval_model['reranking_model']['reranking_model_name']
)
all_documents = hybrid_rerank.rerank(query, all_documents,
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
retrieval_model['top_k'])
start = time.perf_counter()
documents = vector_index.search(
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 10,
'filter': {
'group_id': [dataset.id]
}
}
)
end = time.perf_counter()
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
@@ -67,7 +116,7 @@ class HitTestingService:
db.session.add(dataset_query)
db.session.commit()
return cls.compact_retrieve_response(dataset, embeddings, query, documents)
return cls.compact_retrieve_response(dataset, embeddings, query, all_documents)
@classmethod
def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]):
@@ -99,7 +148,7 @@ class HitTestingService:
record = {
"segment": segment,
"score": document.metadata['score'],
"score": document.metadata.get('score', None),
"tsne_position": tsne_position_data[i]
}
@@ -136,3 +185,11 @@ class HitTestingService:
tsne_position_data.append({'x': float(data_tsne[i][0]), 'y': float(data_tsne[i][1])})
return tsne_position_data
@classmethod
def hit_testing_args_check(cls, args):
query = args['query']
if not query or len(query) > 250:
raise ValueError('Query is required and cannot exceed 250 characters')