From 6a857e01f67ff4c7d0334d66b581724be6578dbb Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Wed, 26 Mar 2025 14:16:21 +0800 Subject: [PATCH] fix multiple metadata filter's confusing setting (#16771) --- api/core/rag/retrieval/dataset_retrieval.py | 33 +++++++++++---- .../knowledge_retrieval_node.py | 41 +++++++++++++------ 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index f88e8629f..21c561f69 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -850,8 +850,9 @@ class DatasetRetrieval: ) if automatic_metadata_filters: conditions = [] - for filter in automatic_metadata_filters: + for sequence, filter in enumerate(automatic_metadata_filters): self._process_metadata_filter_func( + sequence, filter.get("condition"), # type: ignore filter.get("metadata_name"), # type: ignore filter.get("value"), @@ -871,14 +872,18 @@ class DatasetRetrieval: elif metadata_filtering_mode == "manual": if metadata_filtering_conditions: metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump()) - for condition in metadata_filtering_conditions.conditions: # type: ignore + for sequence, condition in enumerate(metadata_filtering_conditions.conditions): # type: ignore metadata_name = condition.name expected_value = condition.value if expected_value is not None or condition.comparison_operator in ("empty", "not empty"): if isinstance(expected_value, str): expected_value = self._replace_metadata_filter_value(expected_value, inputs) filters = self._process_metadata_filter_func( - condition.comparison_operator, metadata_name, expected_value, filters + sequence, + condition.comparison_operator, + metadata_name, + expected_value, + filters, ) else: raise ValueError("Invalid metadata filtering mode") @@ -960,26 +965,36 @@ class DatasetRetrieval: return None return automatic_metadata_filters - def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[Any], filters: list): + def _process_metadata_filter_func( + self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list + ): + key = f"{metadata_name}_{sequence}" + key_value = f"{metadata_name}_{sequence}_value" match condition: case "contains": filters.append( - (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%") + (text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params( + **{key: metadata_name, key_value: f"%{value}%"} + ) ) case "not contains": filters.append( - (text("documents.doc_metadata ->> :key NOT LIKE :value")).params( - key=metadata_name, value=f"%{value}%" + (text(f"documents.doc_metadata ->> :{key} NOT LIKE :{key_value}")).params( + **{key: metadata_name, key_value: f"%{value}%"} ) ) case "start with": filters.append( - (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%") + (text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params( + **{key: metadata_name, key_value: f"{value}%"} + ) ) case "end with": filters.append( - (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}") + (text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params( + **{key: metadata_name, key_value: f"%{value}"} + ) ) case "is" | "=": if isinstance(value, str): diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index bb825e7d4..860373948 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -332,8 +332,9 @@ class KnowledgeRetrievalNode(LLMNode): automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data) if automatic_metadata_filters: conditions = [] - for filter in automatic_metadata_filters: + for sequence, filter in enumerate(automatic_metadata_filters): self._process_metadata_filter_func( + sequence, filter.get("condition", ""), filter.get("metadata_name", ""), filter.get("value"), @@ -354,7 +355,7 @@ class KnowledgeRetrievalNode(LLMNode): if node_data.metadata_filtering_conditions: metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump()) if node_data.metadata_filtering_conditions: - for condition in node_data.metadata_filtering_conditions.conditions: # type: ignore + for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore metadata_name = condition.name expected_value = condition.value if expected_value is not None or condition.comparison_operator in ("empty", "not empty"): @@ -362,14 +363,18 @@ class KnowledgeRetrievalNode(LLMNode): expected_value = self.graph_runtime_state.variable_pool.convert_template( expected_value ).value[0] - if expected_value.value_type == "number": - expected_value = expected_value.value - elif expected_value.value_type == "string": - expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() + if expected_value.value_type == "number": # type: ignore + expected_value = expected_value.value # type: ignore + elif expected_value.value_type == "string": # type: ignore + expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore else: raise ValueError("Invalid expected metadata value type") filters = self._process_metadata_filter_func( - condition.comparison_operator, metadata_name, expected_value, filters + sequence, + condition.comparison_operator, + metadata_name, + expected_value, + filters, ) else: raise ValueError("Invalid metadata filtering mode") @@ -448,25 +453,35 @@ class KnowledgeRetrievalNode(LLMNode): return [] return automatic_metadata_filters - def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[str], filters: list): + def _process_metadata_filter_func( + self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list + ): + key = f"{metadata_name}_{sequence}" + key_value = f"{metadata_name}_{sequence}_value" match condition: case "contains": filters.append( - (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%") + (text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params( + **{key: metadata_name, key_value: f"%{value}%"} + ) ) case "not contains": filters.append( - (text("documents.doc_metadata ->> :key NOT LIKE :value")).params( - key=metadata_name, value=f"%{value}%" + (text(f"documents.doc_metadata ->> :{key} NOT LIKE :{key_value}")).params( + **{key: metadata_name, key_value: f"%{value}%"} ) ) case "start with": filters.append( - (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%") + (text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params( + **{key: metadata_name, key_value: f"{value}%"} + ) ) case "end with": filters.append( - (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}") + (text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params( + **{key: metadata_name, key_value: f"%{value}"} + ) ) case "=" | "is": if isinstance(value, str):