Support knowledge metadata filter (#15982)

Jyong
2025-03-18 16:42:19 +08:00
committed by GitHub
parent b65f2eb55f
commit abeaea4f79
48 changed files with 2502 additions and 574 deletions

View File

@@ -1,35 +1,61 @@
import json
import math
import re
import threading
from collections import Counter
from typing import Any, Optional, cast
from collections import Counter, defaultdict
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union, cast
from flask import Flask, current_app
from sqlalchemy import Integer, and_, or_, text
from sqlalchemy import cast as sqlalchemy_cast
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
from core.app.app_config.entities import (
DatasetEntity,
DatasetRetrieveConfigEntity,
MetadataFilteringCondition,
ModelConfig,
)
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
from core.rag.retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
METADATA_FILTER_COMPLETION_PROMPT,
METADATA_FILTER_SYSTEM_PROMPT,
METADATA_FILTER_USER_PROMPT_1,
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetQuery, DocumentSegment
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
@@ -59,6 +85,7 @@ class DatasetRetrieval:
hit_callback: DatasetIndexToolCallbackHandler,
message_id: str,
memory: Optional[TokenBufferMemory] = None,
inputs: Optional[Mapping[str, Any]] = None,
) -> Optional[str]:
"""
Retrieve dataset.
@@ -116,6 +143,22 @@ class DatasetRetrieval:
continue
available_datasets.append(dataset)
if inputs:
inputs = {key: str(value) for key, value in inputs.items()}
else:
inputs = {}
available_datasets_ids = [dataset.id for dataset in available_datasets]
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
available_datasets_ids,
query,
tenant_id,
user_id,
retrieve_config.metadata_filtering_mode, # type: ignore
retrieve_config.metadata_model_config, # type: ignore
retrieve_config.metadata_filtering_conditions,
inputs,
)
all_documents = []
user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
@@ -130,6 +173,8 @@ class DatasetRetrieval:
model_config,
planning_strategy,
message_id,
metadata_filter_document_ids,
metadata_condition,
)
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
all_documents = self.multiple_retrieve(
@@ -146,6 +191,8 @@ class DatasetRetrieval:
retrieve_config.weights,
retrieve_config.reranking_enabled or True,
message_id,
metadata_filter_document_ids,
metadata_condition,
)
dify_documents = [item for item in all_documents if item.provider == "dify"]
@@ -239,6 +286,8 @@ class DatasetRetrieval:
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
message_id: Optional[str] = None,
metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
metadata_condition: Optional[MetadataCondition] = None,
):
tools = []
for dataset in available_datasets:
@@ -279,6 +328,7 @@ class DatasetRetrieval:
dataset_id=dataset_id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
metadata_condition=metadata_condition,
)
for external_document in external_documents:
document = Document(
@@ -293,6 +343,15 @@ class DatasetRetrieval:
document.metadata["dataset_name"] = dataset.name
results.append(document)
else:
# A metadata condition exists but no document matched it anywhere, so there is nothing to retrieve.
if metadata_condition and not metadata_filter_document_ids:
return []
# Otherwise restrict retrieval to the documents that matched the metadata filters for this dataset.
document_ids_filter = None
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
return []
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
# get top k
@@ -324,6 +383,7 @@ class DatasetRetrieval:
reranking_model=reranking_model,
reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
weights=retrieval_model_config.get("weights", None),
document_ids_filter=document_ids_filter,
)
self._on_query(query, [dataset_id], app_id, user_from, user_id)
@@ -348,6 +408,8 @@ class DatasetRetrieval:
weights: Optional[dict[str, Any]] = None,
reranking_enable: bool = True,
message_id: Optional[str] = None,
metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
metadata_condition: Optional[MetadataCondition] = None,
):
if not available_datasets:
return []
@@ -387,6 +449,16 @@ class DatasetRetrieval:
for dataset in available_datasets:
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
# Skip internal datasets when a metadata condition exists but no document matched it.
if metadata_condition and not metadata_filter_document_ids:
continue
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
@@ -395,6 +467,8 @@ class DatasetRetrieval:
"query": query,
"top_k": top_k,
"all_documents": all_documents,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
},
)
threads.append(retrieval_thread)
@@ -493,7 +567,16 @@ class DatasetRetrieval:
db.session.add_all(dataset_queries)
db.session.commit()
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
def _retriever(
self,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
all_documents: list,
document_ids_filter: Optional[list[str]] = None,
metadata_condition: Optional[MetadataCondition] = None,
):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
@@ -506,6 +589,7 @@ class DatasetRetrieval:
dataset_id=dataset_id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
metadata_condition=metadata_condition,
)
for external_document in external_documents:
document = Document(
@@ -546,6 +630,7 @@ class DatasetRetrieval:
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
document_ids_filter=document_ids_filter,
)
all_documents.extend(documents)
@@ -733,3 +818,340 @@ class DatasetRetrieval:
filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
)
return filter_documents[:top_k] if top_k else filter_documents
def _get_metadata_filter_condition(
self,
dataset_ids: list,
query: str,
tenant_id: str,
user_id: str,
metadata_filtering_mode: str,
metadata_model_config: ModelConfig,
metadata_filtering_conditions: Optional[MetadataFilteringCondition],
inputs: dict,
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
document_query = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id.in_(dataset_ids),
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
filters = [] # type: ignore
metadata_condition = None
if metadata_filtering_mode == "disabled":
return None, None
elif metadata_filtering_mode == "automatic":
automatic_metadata_filters = self._automatic_metadata_filter_func(
dataset_ids, query, tenant_id, user_id, metadata_model_config
)
if automatic_metadata_filters:
conditions = []
for filter in automatic_metadata_filters:
self._process_metadata_filter_func(
filter.get("condition"), # type: ignore
filter.get("metadata_name"), # type: ignore
filter.get("value"),
filters, # type: ignore
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=metadata_filtering_conditions.logical_operator if metadata_filtering_conditions else "and",
conditions=conditions,
)
elif metadata_filtering_mode == "manual":
if metadata_filtering_conditions:
metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump())
for condition in metadata_filtering_conditions.conditions: # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value or condition.comparison_operator in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self._replace_metadata_filter_value(expected_value, inputs)
filters = self._process_metadata_filter_func(
condition.comparison_operator, metadata_name, expected_value, filters
)
else:
raise ValueError("Invalid metadata filtering mode")
if filters:
if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "or":
document_query = document_query.filter(or_(*filters))
else:
document_query = document_query.filter(and_(*filters))
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition
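For orientation, a minimal sketch of the two values this method returns, using illustrative IDs and a single condition; the Condition/MetadataCondition constructors mirror how they are built above:

# Illustrative only: one dataset with two documents whose doc_metadata matched the filters.
metadata_filter_document_ids = {
    "dataset-uuid-1": ["document-uuid-a", "document-uuid-b"],
}
# The MetadataCondition is what gets forwarded to external knowledge retrieval.
metadata_condition = MetadataCondition(
    logical_operator="and",
    conditions=[Condition(name="year", comparison_operator="=", value="2024")],
)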
def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
def replacer(match):
key = match.group(1)
return str(inputs.get(key, f"{{{{{key}}}}}"))
pattern = re.compile(r"\{\{(\w+)\}\}")
return pattern.sub(replacer, text)
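A standalone sketch of the substitution this helper performs, assuming the same regex and fallback behaviour (unknown keys keep their literal {{...}} placeholder):

import re

def replace_metadata_filter_value(text: str, inputs: dict) -> str:
    # Replace {{key}} with inputs[key]; leave unknown keys untouched.
    pattern = re.compile(r"\{\{(\w+)\}\}")
    return pattern.sub(lambda m: str(inputs.get(m.group(1), f"{{{{{m.group(1)}}}}}")), text)

print(replace_metadata_filter_value("published in {{year}} by {{author}}", {"year": "2024"}))
# published in 2024 by {{author}}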
def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> Optional[list[dict[str, Any]]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config
if metadata_model_config is None:
raise ValueError("metadata_model_config is required")
# get metadata model instance
# fetch model config
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
# fetch prompt messages
prompt_messages, stop = self._get_prompt_template(
model_config=model_config,
mode=metadata_model_config.mode,
metadata_fields=all_metadata_fields,
query=query or "",
)
result_text = ""
try:
# handle invoke result
invoke_result = cast(
Generator[LLMResult, None, None],
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_config.parameters,
stop=stop,
stream=True,
user=user_id,
),
)
# handle invoke result
result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
result_text_json = parse_and_check_json_markdown(result_text, [])
automatic_metadata_filters = []
if "metadata_map" in result_text_json:
metadata_map = result_text_json["metadata_map"]
for item in metadata_map:
if item.get("metadata_field_name") in all_metadata_fields:
automatic_metadata_filters.append(
{
"metadata_name": item.get("metadata_field_name"),
"value": item.get("metadata_field_value"),
"condition": item.get("comparison_operator"),
}
)
except Exception:
# On any LLM invocation or JSON parsing failure, fall back to no automatic filters.
return None
return automatic_metadata_filters
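On success the automatic path yields a list of plain dicts (and None on failure); a sketch of the expected shape for a query like "movies rated above 9 in 2024", assuming the datasets define "year" and "rating" metadata fields:

# Hypothetical result of _automatic_metadata_filter_func; the keys match how the result
# is consumed in _get_metadata_filter_condition (metadata_name / value / condition).
automatic_metadata_filters = [
    {"metadata_name": "year", "value": "2024", "condition": "="},
    {"metadata_name": "rating", "value": "9", "condition": ">"},
]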
def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[Any], filters: list):
match condition:
case "contains":
filters.append(
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%")
)
case "not contains":
filters.append(
(text("documents.doc_metadata ->> :key NOT LIKE :value")).params(
key=metadata_name, value=f"%{value}%"
)
)
case "start with":
filters.append(
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%")
)
case "end with":
filters.append(
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}")
)
case "is" | "=":
if isinstance(value, str):
filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
else:
filters.append(
sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) == value
)
case "is not" | "":
if isinstance(value, str):
filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
else:
filters.append(
sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) != value
)
case "empty":
filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) < value)
case "after" | ">":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) > value)
case "" | ">=":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) <= value)
case "" | ">=":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) >= value)
case _:
pass
return filters
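Each branch appends a SQLAlchemy expression over documents.doc_metadata; a minimal sketch of how two such filters combine, assuming two "contains" conditions and the AND path used when the logical operator is not "or":

from sqlalchemy import and_, text

# Sketch only: the "contains" branch builds a parameterised LIKE over the JSON metadata field,
# and _get_metadata_filter_condition joins the accumulated filters with and_()/or_().
filters = [
    text("documents.doc_metadata ->> :key LIKE :value").params(key="author", value="%Smith%"),
    text("documents.doc_metadata ->> :key LIKE :value").params(key="title", value="Intro%"),
]
combined = and_(*filters)  # or_(*filters) is used when logical_operator == "or"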
def _fetch_model_config(
self, tenant_id: str, model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch the model instance and model config for the metadata filter model.
:param tenant_id: tenant id
:param model: metadata filter model config
:return: model instance and model config with credentials
"""
if model is None:
raise ValueError("single_retrieval_config is required")
model_name = model.name
provider_name = model.provider
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
)
provider_model_bundle = model_instance.provider_model_bundle
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_credentials = model_instance.credentials
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_name, model_type=ModelType.LLM
)
if provider_model is None:
raise ValueError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ValueError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ValueError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise ValueError(f"Model provider {provider_name} quota exceeded.")
# model config
completion_params = model.completion_params
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
# get model mode
model_mode = model.mode
if not model_mode:
raise ValueError("LLM mode is required.")
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")
return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,
model=model_name,
model_schema=model_schema,
mode=model_mode,
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
stop=stop,
)
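For context, the fields read here (provider, name, mode, completion_params) are the ones a metadata filter model config must carry; a hypothetical sketch with made-up values:

# Hypothetical metadata_model_config; attribute names follow what _fetch_model_config reads,
# and the values are purely illustrative.
metadata_model_config = ModelConfig(
    provider="openai",          # illustrative provider
    name="gpt-4o-mini",         # illustrative model name
    mode="chat",                # ModelMode consumed by _get_prompt_template
    completion_params={"temperature": 0.2, "stop": []},
)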
def _get_prompt_template(
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
):
model_mode = ModelMode.value_of(mode)
input_text = query
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
if model_mode == ModelMode.CHAT:
prompt_template = []
system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT)
prompt_template.append(system_prompt_messages)
user_prompt_message_1 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1)
prompt_template.append(user_prompt_message_1)
assistant_prompt_message_1 = ChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
)
prompt_template.append(assistant_prompt_message_1)
user_prompt_message_2 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2)
prompt_template.append(user_prompt_message_2)
assistant_prompt_message_2 = ChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
)
prompt_template.append(assistant_prompt_message_2)
user_prompt_message_3 = ChatModelMessage(
role=PromptMessageRole.USER,
text=METADATA_FILTER_USER_PROMPT_3.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
),
)
prompt_template.append(user_prompt_message_3)
elif model_mode == ModelMode.COMPLETION:
prompt_template = CompletionModelPromptTemplate(
text=METADATA_FILTER_COMPLETION_PROMPT.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
)
)
else:
raise ValueError(f"Model mode {model_mode} not support.")
prompt_transform = AdvancedPromptTransform()
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query=query or "",
files=[],
context=None,
memory_config=None,
memory=None,
model_config=model_config,
)
stop = model_config.stop
return prompt_messages, stop
def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
"""
Handle invoke result
:param invoke_result: invoke result
:return:
"""
model = None
prompt_messages: list[PromptMessage] = []
full_text = ""
usage = None
for result in invoke_result:
text = result.delta.message.content
full_text += text
if not model:
model = result.model
if not prompt_messages:
prompt_messages = result.prompt_messages
if not usage and result.delta.usage:
usage = result.delta.usage
if not usage:
usage = LLMUsage.empty_usage()
return full_text, usage
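The LLM reply is expected to carry its JSON inside a fenced block, which parse_and_check_json_markdown unwraps before the metadata_map key is read; a small illustration with a made-up reply, assuming the parser returns the decoded dict as its use above implies:

from libs.json_in_md_parser import parse_and_check_json_markdown

# Made-up model output in the format the few-shot prompts ask for.
llm_reply = '```json\n{"metadata_map": [{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}]}\n```'
result = parse_and_check_json_markdown(llm_reply, [])
for item in result.get("metadata_map", []):
    print(item["metadata_field_name"], item["comparison_operator"], item["metadata_field_value"])
# year = 2024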

View File

@@ -0,0 +1,66 @@
METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description
You are a text metadata extraction engine that extracts a text's metadata based on user input and sets the metadata values.
### Task
Your task is to ONLY extract the metadata that exists in the input text from the provided metadata list, use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, and then return the result in JSON format with the key "metadata_fields", the value "metadata_field_value", and the comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
""" # noqa: E501
METADATA_FILTER_USER_PROMPT_1 = """
{ "input_text": "I want to know which companys email address test@example.com is?",
"metadata_fields": ["filename", "email", "phone", "address"]
}
"""
METADATA_FILTER_ASSISTANT_PROMPT_1 = """
```json
{"metadata_map": [
{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
]
}
```
"""
METADATA_FILTER_USER_PROMPT_2 = """
{"input_text": "What are the movies with a score of more than 9 in 2024?",
"metadata_fields": ["name", "year", "rating", "country"]}
"""
METADATA_FILTER_ASSISTANT_PROMPT_2 = """
```json
{"metadata_map": [
{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
]}
```
"""
METADATA_FILTER_USER_PROMPT_3 = """
'{{"input_text": "{input_text}",',
'"metadata_fields": {metadata_fields}}}'
"""
METADATA_FILTER_COMPLETION_PROMPT = """
### Job Description
You are a text metadata extraction engine that extracts a text's metadata based on user input and sets the metadata values.
### Task
Your task is to ONLY extract the metadata that exists in the input text from the provided metadata list, use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, and then return the result in JSON format with the key "metadata_fields", the value "metadata_field_value", and the comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I want to know which companys email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
</example>
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output
""" # noqa: E501