update knowledge base api (#20426)

This commit is contained in:
Dongyu Li
2025-05-30 14:45:30 +08:00
committed by GitHub
parent 55371e5abf
commit 1ea4459d9f
6 changed files with 397 additions and 36 deletions

View File

@@ -2,8 +2,11 @@ import logging
import time
from typing import Any
from core.app.app_config.entities import ModelConfig
from core.model_runtime.entities import LLMMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.account import Account
@@ -34,7 +37,29 @@ class HitTestingService:
# get retrieval model , if the model is not setting , using default
if not retrieval_model:
retrieval_model = dataset.retrieval_model or default_retrieval_model
document_ids_filter = None
metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
if metadata_filtering_conditions:
dataset_retrieval = DatasetRetrieval()
from core.app.app_config.entities import MetadataFilteringCondition
metadata_filtering_conditions = MetadataFilteringCondition(**metadata_filtering_conditions)
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
dataset_ids=[dataset.id],
query=query,
metadata_filtering_mode="manual",
metadata_filtering_conditions=metadata_filtering_conditions,
inputs={},
tenant_id="",
user_id="",
metadata_model_config=ModelConfig(provider="", name="", mode=LLMMode.CHAT, completion_params={}),
)
if metadata_filter_document_ids:
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
if metadata_condition and not document_ids_filter:
return cls.compact_retrieve_response(query, [])
all_documents = RetrievalService.retrieve(
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id,
@@ -48,6 +73,7 @@ class HitTestingService:
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
document_ids_filter=document_ids_filter,
)
end = time.perf_counter()
@@ -99,7 +125,7 @@ class HitTestingService:
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
@classmethod
def compact_retrieve_response(cls, query: str, documents: list[Document]):
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> dict[Any, Any]:
records = RetrievalService.format_retrieval_documents(documents)
return {