diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index c4a1e9f05..0d400000d 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -5,7 +5,9 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional from flask import Flask, current_app +from sqlalchemy import and_, or_ from sqlalchemy.orm import load_only +from sqlalchemy.sql.expression import false from configs import dify_config from core.rag.data_post_processor.data_post_processor import DataPostProcessor @@ -315,17 +317,29 @@ class RetrievalService: child_chunks = db.session.query(ChildChunk).filter(ChildChunk.index_node_id.in_(child_index_node_ids)).all() child_chunk_map = {chunk.index_node_id: chunk for chunk in child_chunks} - # Batch query DocumentSegment with unified conditions + segment_ids_from_child = [chunk.segment_id for chunk in child_chunks] + segment_conditions = [] + + if index_node_ids: + segment_conditions.append(DocumentSegment.index_node_id.in_(index_node_ids)) + + if segment_ids_from_child: + segment_conditions.append(DocumentSegment.id.in_(segment_ids_from_child)) + + if segment_conditions: + filter_expr = or_(*segment_conditions) + else: + filter_expr = false() + segment_map = { segment.id: segment for segment in db.session.query(DocumentSegment) .filter( - ( - DocumentSegment.index_node_id.in_(index_node_ids) - | DocumentSegment.id.in_([chunk.segment_id for chunk in child_chunks]) - ), - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", + and_( + filter_expr, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + ) ) .options( load_only(