diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index f648b06de..fc2cbba78 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -2344,13 +2344,9 @@ class SegmentService: @classmethod def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): - # Check if segment_ids is not empty to avoid WHERE false condition - if not segment_ids or len(segment_ids) == 0: - return - index_node_ids = ( - db.session.query(DocumentSegment) - .with_entities(DocumentSegment.index_node_id) - .where( + segments = ( + db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count) + .filter( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, DocumentSegment.document_id == document.id, @@ -2358,7 +2354,15 @@ class SegmentService: ) .all() ) - index_node_ids = [index_node_id[0] for index_node_id in index_node_ids] + + if not segments: + return + + index_node_ids = [seg.index_node_id for seg in segments] + total_words = sum(seg.word_count for seg in segments) + + document.word_count -= total_words + db.session.add(document) delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id) db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete()