fix: replace all dataset Model.query with db.session.query(Model) (#19509)

Author: 非法操作
Date: 2025-05-12 13:52:33 +08:00
Committed by: GitHub
Parent: 49af07f444
Commit: b00f94df64
21 changed files with 430 additions and 265 deletions
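The change is mechanical throughout: Flask-SQLAlchemy's legacy `Model.query` accessor is swapped for an explicit session query, and `Query.paginate()` calls move to `db.paginate()` over a SQLAlchemy 2.0 `select()` statement. A minimal sketch of the two patterns, assuming the usual Flask-SQLAlchemy `db` handle and the `Dataset` model from this file (illustrative variables, not taken verbatim from the diff):

    from sqlalchemy import select

    # Legacy pattern removed by this PR:
    # dataset = Dataset.query.filter_by(id=dataset_id).first()

    # Session-bound replacement used throughout:
    dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()

    # Where results are paginated, a 2.0-style select() is built instead:
    stmt = select(Dataset).filter(Dataset.tenant_id == tenant_id)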


@@ -9,7 +9,7 @@ from collections import Counter
 from typing import Any, Optional
 from flask_login import current_user
-from sqlalchemy import func
+from sqlalchemy import func, select
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound
@@ -77,11 +77,13 @@ from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
 class DatasetService:
     @staticmethod
     def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
-        query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
+        query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
 
         if user:
             # get permitted dataset ids
-            dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all()
+            dataset_permission = (
+                db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all()
+            )
             permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None
             if user.current_role == TenantAccountRole.DATASET_OPERATOR:
@@ -129,7 +131,7 @@ class DatasetService:
             else:
                 return [], 0
 
-        datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
+        datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False)
         return datasets.items, datasets.total
@@ -153,9 +155,10 @@ class DatasetService:
     @staticmethod
     def get_datasets_by_ids(ids, tenant_id):
-        datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate(
-            page=1, per_page=len(ids), max_per_page=len(ids), error_out=False
-        )
+        stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
+        datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
         return datasets.items, datasets.total
 
     @staticmethod
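The pagination changes above follow Flask-SQLAlchemy 3.x, where the legacy `Query.paginate()` is replaced by `db.paginate()`, which takes a `select()` statement and returns a `Pagination` object with the familiar `.items` and `.total`. A hedged usage sketch (parameter values are illustrative):

    stmt = (
        select(Dataset)
        .filter(Dataset.tenant_id == tenant_id)
        .order_by(Dataset.created_at.desc())
    )
    pagination = db.paginate(select=stmt, page=1, per_page=20, max_per_page=100, error_out=False)
    items, total = pagination.items, pagination.total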
@@ -174,7 +177,7 @@ class DatasetService:
         retrieval_model: Optional[RetrievalModel] = None,
     ):
         # check if dataset name already exists
-        if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
+        if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
             raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
         embedding_model = None
         if indexing_technique == "high_quality":
@@ -235,7 +238,7 @@ class DatasetService:
     @staticmethod
     def get_dataset(dataset_id) -> Optional[Dataset]:
-        dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first()
+        dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
         return dataset
 
     @staticmethod
@@ -436,7 +439,7 @@ class DatasetService:
             # update Retrieval model
             filtered_data["retrieval_model"] = data["retrieval_model"]
 
-        dataset.query.filter_by(id=dataset_id).update(filtered_data)
+        db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data)
         db.session.commit()
 
         if action:
@@ -460,7 +463,7 @@ class DatasetService:
     @staticmethod
     def dataset_use_check(dataset_id) -> bool:
-        count = AppDatasetJoin.query.filter_by(dataset_id=dataset_id).count()
+        count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
         if count > 0:
             return True
         return False
@@ -475,7 +478,9 @@ class DatasetService:
                 logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
                 raise NoPermissionError("You do not have permission to access this dataset.")
         if dataset.permission == "partial_members":
-            user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first()
+            user_permission = (
+                db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
+            )
             if (
                 not user_permission
                 and dataset.tenant_id != user.current_tenant_id
@@ -499,23 +504,24 @@ class DatasetService:
         elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
             if not any(
-                dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all()
+                dp.dataset_id == dataset.id
+                for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all()
             ):
                 raise NoPermissionError("You do not have permission to access this dataset.")
 
     @staticmethod
     def get_dataset_queries(dataset_id: str, page: int, per_page: int):
-        dataset_queries = (
-            DatasetQuery.query.filter_by(dataset_id=dataset_id)
-            .order_by(db.desc(DatasetQuery.created_at))
-            .paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
-        )
+        stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at))
+        dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
         return dataset_queries.items, dataset_queries.total
 
     @staticmethod
     def get_related_apps(dataset_id: str):
         return (
-            AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id)
+            db.session.query(AppDatasetJoin)
+            .filter(AppDatasetJoin.dataset_id == dataset_id)
             .order_by(db.desc(AppDatasetJoin.created_at))
             .all()
         )
@@ -530,10 +536,14 @@ class DatasetService:
             }
         # get recent 30 days auto disable logs
         start_date = datetime.datetime.now() - datetime.timedelta(days=30)
-        dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
-            DatasetAutoDisableLog.dataset_id == dataset_id,
-            DatasetAutoDisableLog.created_at >= start_date,
-        ).all()
+        dataset_auto_disable_logs = (
+            db.session.query(DatasetAutoDisableLog)
+            .filter(
+                DatasetAutoDisableLog.dataset_id == dataset_id,
+                DatasetAutoDisableLog.created_at >= start_date,
+            )
+            .all()
+        )
         if dataset_auto_disable_logs:
             return {
                 "document_ids": [log.document_id for log in dataset_auto_disable_logs],
@@ -873,7 +883,9 @@ class DocumentService:
     @staticmethod
     def get_documents_position(dataset_id):
-        document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
+        document = (
+            db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
+        )
         if document:
             return document.position + 1
         else:
@@ -1010,13 +1022,17 @@ class DocumentService:
                     }
                     # check duplicate
                     if knowledge_config.duplicate:
-                        document = Document.query.filter_by(
-                            dataset_id=dataset.id,
-                            tenant_id=current_user.current_tenant_id,
-                            data_source_type="upload_file",
-                            enabled=True,
-                            name=file_name,
-                        ).first()
+                        document = (
+                            db.session.query(Document)
+                            .filter_by(
+                                dataset_id=dataset.id,
+                                tenant_id=current_user.current_tenant_id,
+                                data_source_type="upload_file",
+                                enabled=True,
+                                name=file_name,
+                            )
+                            .first()
+                        )
                         if document:
                             document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                             document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
@@ -1054,12 +1070,16 @@ class DocumentService:
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
).all()
documents = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
.all()
)
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
@@ -1206,12 +1226,16 @@ class DocumentService:
     @staticmethod
     def get_tenant_documents_count():
-        documents_count = Document.query.filter(
-            Document.completed_at.isnot(None),
-            Document.enabled == True,
-            Document.archived == False,
-            Document.tenant_id == current_user.current_tenant_id,
-        ).count()
+        documents_count = (
+            db.session.query(Document)
+            .filter(
+                Document.completed_at.isnot(None),
+                Document.enabled == True,
+                Document.archived == False,
+                Document.tenant_id == current_user.current_tenant_id,
+            )
+            .count()
+        )
         return documents_count
 
     @staticmethod
@staticmethod
@@ -1328,7 +1352,7 @@ class DocumentService:
             db.session.commit()
             # update document segment
             update_params = {DocumentSegment.status: "re_segment"}
-            DocumentSegment.query.filter_by(document_id=document.id).update(update_params)
+            db.session.query(DocumentSegment).filter_by(document_id=document.id).update(update_params)
             db.session.commit()
             # trigger async task
             document_indexing_update_task.delay(document.dataset_id, document.id)
@@ -1918,7 +1942,8 @@ class SegmentService:
     @classmethod
     def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
         index_node_ids = (
-            DocumentSegment.query.with_entities(DocumentSegment.index_node_id)
+            db.session.query(DocumentSegment)
+            .with_entities(DocumentSegment.index_node_id)
             .filter(
                 DocumentSegment.id.in_(segment_ids),
                 DocumentSegment.dataset_id == dataset.id,
@@ -2157,20 +2182,28 @@ class SegmentService:
     def get_child_chunks(
         cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
     ):
-        query = ChildChunk.query.filter_by(
-            tenant_id=current_user.current_tenant_id,
-            dataset_id=dataset_id,
-            document_id=document_id,
-            segment_id=segment_id,
-        ).order_by(ChildChunk.position.asc())
+        query = (
+            select(ChildChunk)
+            .filter_by(
+                tenant_id=current_user.current_tenant_id,
+                dataset_id=dataset_id,
+                document_id=document_id,
+                segment_id=segment_id,
+            )
+            .order_by(ChildChunk.position.asc())
+        )
         if keyword:
             query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
-        return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+        return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
 
     @classmethod
     def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
         """Get a child chunk by its ID."""
-        result = ChildChunk.query.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).first()
+        result = (
+            db.session.query(ChildChunk)
+            .filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
+            .first()
+        )
         return result if isinstance(result, ChildChunk) else None
 
     @classmethod
@@ -2184,7 +2217,7 @@ class SegmentService:
         limit: int = 20,
     ):
         """Get segments for a document with optional filtering."""
-        query = DocumentSegment.query.filter(
+        query = select(DocumentSegment).filter(
             DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
         )
@@ -2194,9 +2227,8 @@ class SegmentService:
         if keyword:
             query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))
 
-        paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate(
-            page=page, per_page=limit, max_per_page=100, error_out=False
-        )
+        query = query.order_by(DocumentSegment.position.asc())
+        paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
 
         return paginated_segments.items, paginated_segments.total
@@ -2236,9 +2268,11 @@ class SegmentService:
             raise ValueError(ex.description)
 
         # check segment
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id
-        ).first()
+        segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
+            .first()
+        )
         if not segment:
             raise NotFound("Segment not found.")
@@ -2251,9 +2285,11 @@ class SegmentService:
     @classmethod
     def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
         """Get a segment by its ID."""
-        result = DocumentSegment.query.filter(
-            DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id
-        ).first()
+        result = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
+            .first()
+        )
         return result if isinstance(result, DocumentSegment) else None
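Note that `db.session.query(...)` is still SQLAlchemy's 1.x-style Query API, just bound to an explicit session. If a later pass were to move these reads fully onto the 2.0-style API, the equivalent would use `select()` with `Session.scalars()`; a sketch of that possible follow-up (an assumption about future direction, not part of this PR):

    stmt = select(DocumentSegment).where(
        DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id
    )
    segment = db.session.scalars(stmt).first()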