fix: metadata filtering condition variable unassigned; fix External K… (#19208)
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
from typing import Any
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
@@ -34,7 +35,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
args_schema: type[BaseModel] = DatasetRetrieverToolInput
|
||||
description: str = "use this to retrieve a dataset. "
|
||||
dataset_id: str
|
||||
metadata_filtering_conditions: MetadataCondition
|
||||
user_id: Optional[str] = None
|
||||
retrieve_config: DatasetRetrieveConfigEntity
|
||||
inputs: dict
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset: Dataset, **kwargs):
|
||||
@@ -48,7 +51,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
description=description,
|
||||
metadata_filtering_conditions=MetadataCondition(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -61,6 +63,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
return ""
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_query(query, dataset.id)
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
|
||||
[dataset.id],
|
||||
query,
|
||||
self.tenant_id,
|
||||
self.user_id or "unknown",
|
||||
cast(str, self.retrieve_config.metadata_filtering_mode),
|
||||
cast(ModelConfig, self.retrieve_config.metadata_model_config),
|
||||
self.retrieve_config.metadata_filtering_conditions,
|
||||
self.inputs,
|
||||
)
|
||||
if metadata_filter_document_ids:
|
||||
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
|
||||
else:
|
||||
document_ids_filter = None
|
||||
if dataset.provider == "external":
|
||||
results = []
|
||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
@@ -68,7 +85,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
external_retrieval_parameters=dataset.retrieval_model,
|
||||
metadata_condition=self.metadata_filtering_conditions,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
for external_document in external_documents:
|
||||
document = RetrievalDocument(
|
||||
@@ -104,12 +121,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
|
||||
return str("\n".join([item.page_content for item in results]))
|
||||
else:
|
||||
if metadata_condition and not document_ids_filter:
|
||||
return ""
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
|
||||
retrieval_method="keyword_search",
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k,
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
return str("\n".join([document.page_content for document in documents]))
|
||||
else:
|
||||
@@ -128,6 +151,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights"),
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
else:
|
||||
documents = []
|
||||
|
@@ -34,6 +34,8 @@ class DatasetRetrieverTool(Tool):
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
user_id: str,
|
||||
inputs: dict,
|
||||
) -> list["DatasetRetrieverTool"]:
|
||||
"""
|
||||
get dataset tool
|
||||
@@ -57,6 +59,8 @@ class DatasetRetrieverTool(Tool):
|
||||
return_resource=return_resource,
|
||||
invoke_from=invoke_from,
|
||||
hit_callback=hit_callback,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
)
|
||||
if retrieval_tools is None or len(retrieval_tools) == 0:
|
||||
return []
|
||||
|
Reference in New Issue
Block a user