Feat/support parent child chunk (#12092)
@@ -17,6 +17,7 @@ from sqlalchemy.dialects.postgresql import JSONB
from configs import dify_config
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule

from .account import Account
from .engine import db
@@ -215,7 +216,7 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined]
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

-    MODES = ["automatic", "custom"]
+    MODES = ["automatic", "custom", "hierarchical"]
    PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
    AUTOMATIC_RULES: dict[str, Any] = {
        "pre_processing_rules": [
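For orientation, a minimal sketch of how a process rule in the new "hierarchical" mode might be interpreted; apart from parent_mode (compared against ParentMode.FULL_DOC later in this diff), the fields and values shown are assumptions, not taken from this PR:

# Sketch only; the "paragraph" value and the minimal rules_dict are assumed.
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule

rules_dict = {"parent_mode": "paragraph"}  # assumed value distinct from ParentMode.FULL_DOC
rule = Rule(**rules_dict)
if rule.parent_mode and rule.parent_mode != ParentMode.FULL_DOC:
    # parent segments keep their own child chunks; see DocumentSegment.child_chunks below
    ...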
@@ -231,8 +232,6 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined]
            "dataset_id": self.dataset_id,
            "mode": self.mode,
            "rules": self.rules_dict,
            "created_by": self.created_by,
            "created_at": self.created_at,
        }

    @property
@@ -396,6 +395,12 @@ class Document(db.Model): # type: ignore[name-defined]
            .scalar()
        )

    @property
    def process_rule_dict(self):
        if self.dataset_process_rule_id:
            return self.dataset_process_rule.to_dict()
        return None

    def to_dict(self):
        return {
            "id": self.id,
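A hedged usage sketch for the Document.process_rule_dict property added above: it proxies to the linked DatasetProcessRule.to_dict(), so callers can treat it as an optional dict. The lookup and the document_id variable below are illustrative, not code from this PR:

# Illustrative only: fetch a document and read its processing rule as a plain dict.
document = db.session.query(Document).filter(Document.id == document_id).first()  # document_id is hypothetical
rule_dict = document.process_rule_dict  # None when no dataset_process_rule_id is set
if rule_dict and rule_dict.get("mode") == "hierarchical":
    print(rule_dict["rules"])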
@@ -560,6 +565,24 @@ class DocumentSegment(db.Model): # type: ignore[name-defined]
            .first()
        )

    @property
    def child_chunks(self):
        process_rule = self.document.dataset_process_rule
        if process_rule.mode == "hierarchical":
            rules = Rule(**process_rule.rules_dict)
            if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
                child_chunks = (
                    db.session.query(ChildChunk)
                    .filter(ChildChunk.segment_id == self.id)
                    .order_by(ChildChunk.position.asc())
                    .all()
                )
                return child_chunks or []
            else:
                return []
        else:
            return []

    def get_sign_content(self):
        signed_urls = []
        text = self.content
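A short sketch of reading the new child_chunks property on a parent segment; it returns a position-ordered list only when the document was processed in "hierarchical" mode with a non-full-doc parent_mode, and an empty list otherwise. The segment lookup and segment_id variable are illustrative:

# Illustrative only: walk the child chunks of one parent segment.
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()  # segment_id is hypothetical
for child in segment.child_chunks:  # [] unless hierarchical mode with a non-FULL_DOC parent_mode
    print(child.position, child.word_count, child.content[:80])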
@@ -605,6 +628,47 @@ class DocumentSegment(db.Model): # type: ignore[name-defined]
        return text


class ChildChunk(db.Model):
    __tablename__ = "child_chunks"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
        db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    segment_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    content = db.Column(db.Text, nullable=False)
    word_count = db.Column(db.Integer, nullable=False)
    # indexing fields
    index_node_id = db.Column(db.String(255), nullable=True)
    index_node_hash = db.Column(db.String(255), nullable=True)
    type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    indexing_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    error = db.Column(db.Text, nullable=True)

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()

    @property
    def document(self):
        return db.session.query(Document).filter(Document.id == self.document_id).first()

    @property
    def segment(self):
        return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()

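A minimal sketch of how a row of the new ChildChunk model might be created when a parent segment is split; the id/position/word-count handling and the current_user_id variable are assumptions, not code from this PR:

import uuid

# Sketch only: persist one child chunk under an existing parent segment.
child = ChildChunk(
    id=str(uuid.uuid4()),          # the column also has a uuid_generate_v4() server default
    tenant_id=segment.tenant_id,   # "segment" is an existing DocumentSegment (hypothetical variable)
    dataset_id=segment.dataset_id,
    document_id=segment.document_id,
    segment_id=segment.id,
    position=1,
    content="first child chunk text",
    word_count=4,
    created_by=current_user_id,    # current_user_id is hypothetical
)
db.session.add(child)
db.session.commit()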
class AppDatasetJoin(db.Model): # type: ignore[name-defined]
    __tablename__ = "app_dataset_joins"
    __table_args__ = (
@@ -844,3 +908,20 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class DatasetAutoDisableLog(db.Model):
    __tablename__ = "dataset_auto_disable_logs"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
        db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),
        db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),
        db.Index("dataset_auto_disable_log_created_atx", "created_at"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
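And a hedged sketch of how the new DatasetAutoDisableLog table might be queried, e.g. to find documents in a tenant that were auto-disabled but not yet notified; the surrounding task code and the tenant_id variable are assumed, not part of this diff:

# Illustrative only: collect pending auto-disable notifications for one tenant.
pending = (
    db.session.query(DatasetAutoDisableLog)
    .filter(
        DatasetAutoDisableLog.tenant_id == tenant_id,  # tenant_id is a hypothetical variable
        DatasetAutoDisableLog.notified == False,  # noqa: E712 - SQLAlchemy column comparison
    )
    .order_by(DatasetAutoDisableLog.created_at.asc())
    .all()
)
document_ids = [log.document_id for log in pending]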