Feat/delete single dataset retrival (#6570)

This commit is contained in:
Jyong
2024-07-24 12:50:11 +08:00
committed by GitHub
parent 0fb741f269
commit e4bb943fe5
22 changed files with 651 additions and 115 deletions

View File

@@ -6,6 +6,7 @@ from flask import Flask, current_app
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.rerank.constants.rerank_mode import RerankMode
from core.rag.retrieval.retrival_methods import RetrievalMethod
from extensions.ext_database import db
from models.dataset import Dataset
@@ -26,13 +27,19 @@ class RetrievalService:
@classmethod
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
top_k: int, score_threshold: Optional[float] = .0,
reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = None,
weights: Optional[dict] = None):
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
all_documents = []
keyword_search_documents = []
embedding_search_documents = []
full_text_search_documents = []
hybrid_search_documents = []
threads = []
exceptions = []
# retrieval_model source with keyword
@@ -87,7 +94,8 @@ class RetrievalService:
raise Exception(exception_message)
if retrival_method == RetrievalMethod.HYBRID_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
reranking_model, weights, False)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
@@ -143,7 +151,9 @@ class RetrievalService:
if documents:
if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
RerankMode.RERANKING_MODEL.value,
reranking_model, None, False)
all_documents.extend(data_post_processor.invoke(
query=query,
documents=documents,
@@ -175,7 +185,9 @@ class RetrievalService:
)
if documents:
if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
RerankMode.RERANKING_MODEL.value,
reranking_model, None, False)
all_documents.extend(data_post_processor.invoke(
query=query,
documents=documents,