fix: metadata filtering condition variable unassigned; fix External K… (#19208)

This commit is contained in:
Will
2025-05-07 14:52:09 +08:00
committed by GitHub
parent d1c08a810b
commit bfa652f2d0
10 changed files with 101 additions and 40 deletions

View File

@@ -1,11 +1,12 @@
from typing import Any
from typing import Any, Optional, cast
from pydantic import BaseModel, Field
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
@@ -34,7 +35,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
args_schema: type[BaseModel] = DatasetRetrieverToolInput
description: str = "use this to retrieve a dataset. "
dataset_id: str
metadata_filtering_conditions: MetadataCondition
user_id: Optional[str] = None
retrieve_config: DatasetRetrieveConfigEntity
inputs: dict
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
@@ -48,7 +51,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
description=description,
metadata_filtering_conditions=MetadataCondition(),
**kwargs,
)
@@ -61,6 +63,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
return ""
for hit_callback in self.hit_callbacks:
hit_callback.on_query(query, dataset.id)
dataset_retrieval = DatasetRetrieval()
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
[dataset.id],
query,
self.tenant_id,
self.user_id or "unknown",
cast(str, self.retrieve_config.metadata_filtering_mode),
cast(ModelConfig, self.retrieve_config.metadata_model_config),
self.retrieve_config.metadata_filtering_conditions,
self.inputs,
)
if metadata_filter_document_ids:
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
else:
document_ids_filter = None
if dataset.provider == "external":
results = []
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
@@ -68,7 +85,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
dataset_id=dataset.id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
metadata_condition=self.metadata_filtering_conditions,
metadata_condition=metadata_condition,
)
for external_document in external_documents:
document = RetrievalDocument(
@@ -104,12 +121,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
return str("\n".join([item.page_content for item in results]))
else:
if metadata_condition and not document_ids_filter:
return ""
# get retrieval model , if the model is not setting , using default
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
retrieval_method="keyword_search",
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
document_ids_filter=document_ids_filter,
)
return str("\n".join([document.page_content for document in documents]))
else:
@@ -128,6 +151,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights"),
document_ids_filter=document_ids_filter,
)
else:
documents = []

View File

@@ -34,6 +34,8 @@ class DatasetRetrieverTool(Tool):
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
user_id: str,
inputs: dict,
) -> list["DatasetRetrieverTool"]:
"""
get dataset tool
@@ -57,6 +59,8 @@ class DatasetRetrieverTool(Tool):
return_resource=return_resource,
invoke_from=invoke_from,
hit_callback=hit_callback,
user_id=user_id,
inputs=inputs,
)
if retrieval_tools is None or len(retrieval_tools) == 0:
return []