more assert (#24996)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2025-09-08 10:59:43 +09:00
committed by GitHub
parent 98204d78fb
commit 16a3e21410
17 changed files with 235 additions and 90 deletions

View File

@@ -8,7 +8,7 @@ import uuid
from collections import Counter
from typing import Any, Literal, Optional
from flask_login import current_user
import sqlalchemy as sa
from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@@ -27,6 +27,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@@ -498,8 +499,11 @@ class DatasetService:
data: Update data dictionary
filtered_data: Filtered update data to modify
"""
# assert isinstance(current_user, Account) and current_user.current_tenant_id is not None
try:
model_manager = ModelManager()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"],
@@ -611,8 +615,12 @@ class DatasetService:
data: Update data dictionary
filtered_data: Filtered update data to modify
"""
# assert isinstance(current_user, Account) and current_user.current_tenant_id is not None
model_manager = ModelManager()
try:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"],
@@ -720,6 +728,8 @@ class DatasetService:
@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
return {
@@ -924,6 +934,8 @@ class DocumentService:
@staticmethod
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
assert isinstance(current_user, Account)
documents = (
db.session.query(Document)
.where(
@@ -983,6 +995,8 @@ class DocumentService:
@staticmethod
def rename_document(dataset_id: str, document_id: str, name: str) -> Document:
assert isinstance(current_user, Account)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise ValueError("Dataset not found.")
@@ -1012,6 +1026,7 @@ class DocumentService:
if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}:
raise DocumentIndexingError()
# update document to be paused
assert current_user is not None
document.is_paused = True
document.paused_by = current_user.id
document.paused_at = naive_utc_now()
@@ -1098,6 +1113,9 @@ class DocumentService:
# check doc_form
DatasetService.check_doc_form(dataset, knowledge_config.doc_form)
# check document limit
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
@@ -1434,6 +1452,8 @@ class DocumentService:
@staticmethod
def get_tenant_documents_count():
assert isinstance(current_user, Account)
documents_count = (
db.session.query(Document)
.where(
@@ -1454,6 +1474,8 @@ class DocumentService:
dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = "web",
):
assert isinstance(current_user, Account)
DatasetService.check_dataset_model_setting(dataset)
document = DocumentService.get_document(dataset.id, document_data.original_document_id)
if document is None:
@@ -1513,7 +1535,7 @@ class DocumentService:
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.where(
db.and_(
sa.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
@@ -1574,6 +1596,9 @@ class DocumentService:
@staticmethod
def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
@@ -2013,6 +2038,9 @@ class SegmentService:
@classmethod
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
content = args["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
@@ -2075,6 +2103,9 @@ class SegmentService:
@classmethod
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
lock_name = f"multi_add_segment_lock_document_id_{document.id}"
increment_word_count = 0
with redis_client.lock(lock_name, timeout=600):
@@ -2158,6 +2189,9 @@ class SegmentService:
@classmethod
def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
@@ -2349,6 +2383,7 @@ class SegmentService:
@classmethod
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
segments = (
db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count)
.where(
@@ -2379,6 +2414,8 @@ class SegmentService:
def update_segments_status(
cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document
):
assert current_user is not None
# Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0:
return
@@ -2441,6 +2478,8 @@ class SegmentService:
def create_child_chunk(
cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset
) -> ChildChunk:
assert isinstance(current_user, Account)
lock_name = f"add_child_lock_{segment.id}"
with redis_client.lock(lock_name, timeout=20):
index_node_id = str(uuid.uuid4())
@@ -2488,6 +2527,8 @@ class SegmentService:
document: Document,
dataset: Dataset,
) -> list[ChildChunk]:
assert isinstance(current_user, Account)
child_chunks = (
db.session.query(ChildChunk)
.where(
@@ -2562,6 +2603,8 @@ class SegmentService:
document: Document,
dataset: Dataset,
) -> ChildChunk:
assert current_user is not None
try:
child_chunk.content = content
child_chunk.word_count = len(content)
@@ -2592,6 +2635,8 @@ class SegmentService:
def get_child_chunks(
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
):
assert isinstance(current_user, Account)
query = (
select(ChildChunk)
.filter_by(