Support knowledge metadata filter (#15982)

This commit is contained in:
Jyong
2025-03-18 16:42:19 +08:00
committed by GitHub
parent b65f2eb55f
commit abeaea4f79
48 changed files with 2502 additions and 574 deletions

View File

@@ -1,8 +1,10 @@
from collections.abc import Sequence
from typing import Any, Literal, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import VisionConfig
class RerankingModelConfig(BaseModel):
@@ -73,6 +75,48 @@ class SingleRetrievalConfig(BaseModel):
model: ModelConfig
SupportedComparisonOperator = Literal[
# for string or array
"contains",
"not contains",
"start with",
"end with",
"is",
"is not",
"empty",
"not empty",
# for number
"=",
"",
">",
"<",
"",
"",
# for time
"before",
"after",
]
class Condition(BaseModel):
"""
Conditon detail
"""
name: str
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None | int | float = None
class MetadataFilteringCondition(BaseModel):
"""
Metadata Filtering Condition.
"""
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
class KnowledgeRetrievalNodeData(BaseNodeData):
"""
Knowledge retrieval Node Data.
@@ -84,3 +128,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
retrieval_mode: Literal["single", "multiple"]
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
single_retrieval_config: Optional[SingleRetrievalConfig] = None
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
vision: VisionConfig = Field(default_factory=VisionConfig)

View File

@@ -16,3 +16,7 @@ class ModelNotSupportedError(KnowledgeRetrievalNodeError):
class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
"""Raised when the model provider quota is exceeded."""
class InvalidModelTypeError(KnowledgeRetrievalNodeError):
"""Raised when the model is not a Large Language Model."""

View File

@@ -1,32 +1,51 @@
import json
import logging
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Any, Optional, cast
from sqlalchemy import func
from sqlalchemy import Integer, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
METADATA_FILTER_COMPLETION_PROMPT,
METADATA_FILTER_SYSTEM_PROMPT,
METADATA_FILTER_USER_PROMPT_1,
METADATA_FILTER_USER_PROMPT_3,
)
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document, RateLimitLog
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from models.workflow import WorkflowNodeExecutionStatus
from services.feature_service import FeatureService
from .entities import KnowledgeRetrievalNodeData
from .entities import KnowledgeRetrievalNodeData, ModelConfig
from .exc import (
InvalidModelTypeError,
KnowledgeRetrievalNodeError,
ModelCredentialsNotInitializedError,
ModelNotExistError,
@@ -45,13 +64,14 @@ default_retrieval_model = {
}
class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
_node_data_cls = KnowledgeRetrievalNodeData
class KnowledgeRetrievalNode(LLMNode):
_node_data_cls = KnowledgeRetrievalNodeData # type: ignore
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
def _run(self) -> NodeRunResult:
def _run(self) -> NodeRunResult: # type: ignore
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
# extract variables
variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@@ -91,7 +111,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
# retrieve knowledge
try:
results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
outputs = {"result": results}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
@@ -145,11 +165,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
if not dataset:
continue
available_datasets.append(dataset)
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
[dataset.id for dataset in available_datasets], query, node_data
)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data)
model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model) # type: ignore
# check model is support tool calling
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
@@ -174,6 +197,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
model_config=model_config,
model_instance=model_instance,
planning_strategy=planning_strategy,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
if node_data.multiple_retrieval_config is None:
@@ -220,6 +245,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
reranking_model=reranking_model,
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
)
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
@@ -287,13 +314,187 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
item["metadata"]["position"] = position
return retrieval_resource_list
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
document_query = db.session.query(Document).filter(
Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
filters = [] # type: ignore
metadata_condition = None
if node_data.metadata_filtering_mode == "disabled":
return None, None
elif node_data.metadata_filtering_mode == "automatic":
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
if automatic_metadata_filters:
conditions = []
for filter in automatic_metadata_filters:
self._process_metadata_filter_func(
filter.get("condition", ""),
filter.get("metadata_name", ""),
filter.get("value"),
filters, # type: ignore
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator, # type: ignore
conditions=conditions,
)
elif node_data.metadata_filtering_mode == "manual":
if node_data.metadata_filtering_conditions:
metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump())
if node_data.metadata_filtering_conditions:
for condition in node_data.metadata_filtering_conditions.conditions: # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value or condition.comparison_operator in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).text
filters = self._process_metadata_filter_func(
condition.comparison_operator, metadata_name, expected_value, filters
)
else:
raise ValueError("Invalid metadata filtering mode")
if filters:
if node_data.metadata_filtering_conditions.logical_operator == "and": # type: ignore
document_query = document_query.filter(and_(*filters))
else:
document_query = document_query.filter(or_(*filters))
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition
def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.field_name for metadata_field in metadata_fields]
# get metadata model config
metadata_model_config = node_data.metadata_model_config
if metadata_model_config is None:
raise ValueError("metadata_model_config is required")
# get metadata model instance
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config) # type: ignore
# fetch prompt messages
prompt_template = self._get_prompt_template(
node_data=node_data,
metadata_fields=all_metadata_fields,
query=query or "",
)
prompt_messages, stop = self._fetch_prompt_messages(
prompt_template=prompt_template,
sys_query=query,
memory=None,
model_config=model_config,
sys_files=[],
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
)
result_text = ""
try:
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.metadata_model_config, # type: ignore
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
)
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
break
result_text_json = parse_and_check_json_markdown(result_text, [])
automatic_metadata_filters = []
if "metadata_map" in result_text_json:
metadata_map = result_text_json["metadata_map"]
for item in metadata_map:
if item.get("metadata_field_name") in all_metadata_fields:
automatic_metadata_filters.append(
{
"metadata_name": item.get("metadata_field_name"),
"value": item.get("metadata_field_value"),
"condition": item.get("comparison_operator"),
}
)
except Exception as e:
return []
return automatic_metadata_filters
def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[str], filters: list):
match condition:
case "contains":
filters.append(
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%")
)
case "not contains":
filters.append(
(text("documents.doc_metadata ->> :key NOT LIKE :value")).params(
key=metadata_name, value=f"%{value}%"
)
)
case "start with":
filters.append(
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%")
)
case "end with":
filters.append(
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}")
)
case "=" | "is":
if isinstance(value, str):
filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
else:
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) == value)
case "is not" | "":
if isinstance(value, str):
filters.append(Document.doc_metadata[metadata_name] != f'"{value}"')
else:
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) != value)
case "empty":
filters.append(Document.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(Document.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) < value)
case "after" | ">":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) > value)
case "" | ">=":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) <= value)
case "" | ">=":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) >= value)
case _:
pass
return filters
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: KnowledgeRetrievalNodeData,
node_data: KnowledgeRetrievalNodeData, # type: ignore
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -306,18 +507,16 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
return variable_mapping
def _fetch_model_config(
self, node_data: KnowledgeRetrievalNodeData
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: # type: ignore
"""
Fetch model config
:param node_data: node data
:param model: model
:return:
"""
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required")
model_name = node_data.single_retrieval_config.model.name
provider_name = node_data.single_retrieval_config.model.provider
if model is None:
raise ValueError("model is required")
model_name = model.name
provider_name = model.provider
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
@@ -346,14 +545,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config
completion_params = node_data.single_retrieval_config.model.completion_params
completion_params = model.completion_params
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
# get model mode
model_mode = node_data.single_retrieval_config.model.mode
model_mode = model.mode
if not model_mode:
raise ModelNotExistError("LLM mode is required.")
@@ -372,3 +571,50 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
parameters=completion_params,
stop=stop,
)
def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
model_mode = ModelMode.value_of(node_data.metadata_model_config.mode) # type: ignore
input_text = query
memory_str = ""
prompt_messages: list[LLMNodeChatModelMessage] = []
if model_mode == ModelMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT
)
prompt_messages.append(system_prompt_messages)
user_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1
)
prompt_messages.append(user_prompt_message_1)
assistant_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
)
prompt_messages.append(assistant_prompt_message_1)
user_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
)
prompt_messages.append(user_prompt_message_2)
assistant_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
)
prompt_messages.append(assistant_prompt_message_2)
user_prompt_message_3 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER,
text=METADATA_FILTER_USER_PROMPT_3.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
),
)
prompt_messages.append(user_prompt_message_3)
return prompt_messages
elif model_mode == ModelMode.COMPLETION:
return LLMNodeCompletionModelPromptTemplate(
text=METADATA_FILTER_COMPLETION_PROMPT.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
)
)
else:
raise InvalidModelTypeError(f"Model mode {model_mode} not support.")

View File

@@ -0,0 +1,66 @@
METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description',
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
""" # noqa: E501
METADATA_FILTER_USER_PROMPT_1 = """
{ "input_text": "I want to know which companys email address test@example.com is?",
"metadata_fields": ["filename", "email", "phone", "address"]
}
"""
METADATA_FILTER_ASSISTANT_PROMPT_1 = """
```json
{"metadata_map": [
{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
]
}
```
"""
METADATA_FILTER_USER_PROMPT_2 = """
{"input_text": "What are the movies with a score of more than 9 in 2024?",
"metadata_fields": ["name", "year", "rating", "country"]}
"""
METADATA_FILTER_ASSISTANT_PROMPT_2 = """
```json
{"metadata_map": [
{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
]}
```
"""
METADATA_FILTER_USER_PROMPT_3 = """
'{{"input_text": "{input_text}",',
'"metadata_fields": {metadata_fields}}}'
"""
METADATA_FILTER_COMPLETION_PROMPT = """
### Job Description
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I want to know which companys email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
</example>
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output
""" # noqa: E501