feat: mypy for all type check (#10921)

yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions


@@ -6,7 +6,7 @@ import time
 import uuid
 from typing import Any, Optional
-from flask_login import current_user
+from flask_login import current_user # type: ignore
 from sqlalchemy import func
 from werkzeug.exceptions import NotFound
@@ -186,8 +186,9 @@ class DatasetService:
         return dataset

     @staticmethod
-    def get_dataset(dataset_id) -> Dataset:
-        return Dataset.query.filter_by(id=dataset_id).first()
+    def get_dataset(dataset_id) -> Optional[Dataset]:
+        dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first()
+        return dataset

     @staticmethod
     def check_dataset_model_setting(dataset):
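Aside: the hunk above shows the commit's most common pattern. SQLAlchemy's first() can return None, so the return annotation widens to Optional[Dataset] and every caller must handle the missing-row case, as the update_dataset hunk below does. A minimal standalone sketch of the idea, with hypothetical names and a plain dict standing in for SQLAlchemy:

    from typing import Optional

    class Record:
        def __init__(self, record_id: str) -> None:
            self.id = record_id

    _TABLE = {"a1": Record("a1")}

    def get_record(record_id: str) -> Optional[Record]:
        # Like Query.first(): may legitimately find nothing.
        return _TABLE.get(record_id)

    def update_record(record_id: str) -> None:
        record = get_record(record_id)
        if not record:
            # Without this guard, mypy flags attribute access on an Optional value.
            raise ValueError("Record not found")
        print(record.id)  # mypy has narrowed record to Record here

    update_record("a1")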
@@ -228,6 +229,8 @@ class DatasetService:
     @staticmethod
     def update_dataset(dataset_id, data, user):
         dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise ValueError("Dataset not found")
         DatasetService.check_dataset_permission(dataset, user)
         if dataset.provider == "external":
@@ -371,7 +374,13 @@ class DatasetService:
                 raise NoPermissionError("You do not have permission to access this dataset.")

     @staticmethod
-    def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None):
+    def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None):
+        if not dataset:
+            raise ValueError("Dataset not found")
+
+        if not user:
+            raise ValueError("User not found")
+
         if dataset.permission == DatasetPermissionEnum.ONLY_ME:
             if dataset.created_by != user.id:
                 raise NoPermissionError("You do not have permission to access this dataset.")
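Aside: "user: Account = None" is an implicit-Optional signature, which mypy rejects by default (no_implicit_optional became the default in mypy 0.990). The fix spells out Optional[...] and adds runtime guards so the rest of the body sees non-None values. A standalone sketch, with a hypothetical User class standing in for Account:

    from typing import Optional

    class User:
        def __init__(self, user_id: str) -> None:
            self.id = user_id

    # Under the modern default, "user: User = None" is rejected by mypy,
    # so a None default has to be written as Optional[User].
    def check_owner(user: Optional[User] = None, owner_id: Optional[str] = None) -> None:
        if not user:
            raise ValueError("User not found")
        if not owner_id:
            raise ValueError("Owner not found")
        # Both names are narrowed to non-Optional types past the guards.
        if user.id != owner_id:
            raise PermissionError("You do not have permission to access this resource.")

    check_owner(User("u1"), "u1")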
@@ -765,6 +774,11 @@ class DocumentService:
                     rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                     created_by=account.id,
                 )
+            else:
+                logging.warn(
+                    f"Invalid process rule mode: {process_rule['mode']}, can not find dataset process rule"
+                )
+                return
             db.session.add(dataset_process_rule)
             db.session.commit()
         lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
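Aside: presumably the point of the new else branch is that, when process_rule["mode"] matches neither "custom" nor "automatic", dataset_process_rule is never built, and the db.session.add(...) that follows would operate on an unset or None value; the code now logs the unexpected mode and returns early instead. (logging.warn is a deprecated alias of logging.warning, kept here as committed.) A sketch of the shape, with hypothetical names:

    import logging
    from typing import Optional

    def save(rule: dict) -> None:
        # Hypothetical persistence step, standing in for db.session.add/commit.
        print("saved", rule)

    def build_rule(mode: str) -> Optional[dict]:
        if mode == "custom":
            rule = {"mode": mode, "rules": "{}"}
        elif mode == "automatic":
            rule = {"mode": mode, "rules": "default"}
        else:
            # Unexpected mode: log and return early instead of falling through
            # to code that assumes rule was built.
            logging.warning("Invalid process rule mode: %s", mode)
            return None
        save(rule)  # rule is bound on every path that reaches this line
        return rule

    build_rule("automatic")
    build_rule("bogus")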
@@ -1009,9 +1023,10 @@ class DocumentService:
                     rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                     created_by=account.id,
                 )
-            db.session.add(dataset_process_rule)
-            db.session.commit()
-            document.dataset_process_rule_id = dataset_process_rule.id
+            if dataset_process_rule is not None:
+                db.session.add(dataset_process_rule)
+                db.session.commit()
+                document.dataset_process_rule_id = dataset_process_rule.id
         # update document data source
         if document_data.get("data_source"):
             file_name = ""
@@ -1554,7 +1569,7 @@ class SegmentService:
             segment.word_count = len(content)
             if document.doc_form == "qa_model":
                 segment.answer = segment_update_entity.answer
-                segment.word_count += len(segment_update_entity.answer)
+                segment.word_count += len(segment_update_entity.answer or "")
             word_count_change = segment.word_count - word_count_change
             if segment_update_entity.keywords:
                 segment.keywords = segment_update_entity.keywords
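Aside: segment_update_entity.answer is presumably typed Optional[str], so len(...) on it fails both mypy and, when the answer is missing, the runtime; "answer or ''" substitutes an empty string for None. The hunks below apply the same idiom again and use "keywords or []" for an optional list. A standalone sketch with a hypothetical entity class:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class SegmentUpdate:
        content: str
        answer: Optional[str] = None          # absent for non-QA documents
        keywords: Optional[list[str]] = None  # may be absent as well

    def word_count(update: SegmentUpdate) -> int:
        count = len(update.content)
        # len(None) fails at runtime and under mypy; default to "" instead.
        count += len(update.answer or "")
        return count

    def keyword_list(update: SegmentUpdate) -> list[str]:
        # Same idea for lists: hand a concrete list to code expecting list[str].
        return update.keywords or []

    print(word_count(SegmentUpdate(content="hello")))    # 5
    print(keyword_list(SegmentUpdate(content="hello")))  # []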
@@ -1569,7 +1584,8 @@ class SegmentService:
             db.session.add(document)
             # update segment index task
             if segment_update_entity.enabled:
-                VectorService.create_segments_vector([segment_update_entity.keywords], [segment], dataset)
+                keywords = segment_update_entity.keywords or []
+                VectorService.create_segments_vector([keywords], [segment], dataset)
         else:
             segment_hash = helper.generate_text_hash(content)
             tokens = 0
@@ -1601,7 +1617,7 @@ class SegmentService:
             segment.disabled_by = None
             if document.doc_form == "qa_model":
                 segment.answer = segment_update_entity.answer
-                segment.word_count += len(segment_update_entity.answer)
+                segment.word_count += len(segment_update_entity.answer or "")
             word_count_change = segment.word_count - word_count_change
             # update document word count
             if word_count_change != 0:
@@ -1619,8 +1635,8 @@ class SegmentService:
             segment.status = "error"
             segment.error = str(e)
             db.session.commit()
-        segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
-        return segment
+        new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
+        return new_segment

     @classmethod
     def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
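Aside: reusing the name segment above would retype it: the parameter is annotated as DocumentSegment, while query(...).first() returns an Optional, and mypy reports that as an incompatible assignment. Binding the re-fetched row to a fresh name keeps both types honest. A small sketch of the same situation, with hypothetical names:

    from typing import Optional

    class Segment:
        def __init__(self, seg_id: str, status: str = "ok") -> None:
            self.id = seg_id
            self.status = status

    _TABLE: dict[str, Segment] = {"s1": Segment("s1")}

    def refetch(seg_id: str) -> Optional[Segment]:
        # Stands in for db.session.query(...).first(): may return None.
        return _TABLE.get(seg_id)

    def update_segment(segment: Segment) -> Optional[Segment]:
        segment.status = "completed"
        # "segment = refetch(segment.id)" would be an incompatible assignment:
        # segment is declared Segment, the result is Optional[Segment].
        new_segment = refetch(segment.id)
        return new_segment

    print(update_segment(Segment("s1")))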
@@ -1680,6 +1696,8 @@ class DatasetCollectionBindingService:
             .order_by(DatasetCollectionBinding.created_at)
             .first()
         )
+        if not dataset_collection_binding:
+            raise ValueError("Dataset collection binding not found")
         return dataset_collection_binding
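Aside: here the method presumably keeps a non-Optional DatasetCollectionBinding return type while first() can yield None, so the missing-row case is turned into an exception rather than pushed onto every caller as an Optional. A sketch of that trade-off, with hypothetical names:

    from typing import Optional

    class Binding:
        def __init__(self, name: str) -> None:
            self.name = name

    _BINDINGS: dict[str, Binding] = {"default": Binding("default")}

    def query_first(name: str) -> Optional[Binding]:
        # Stands in for the SQLAlchemy query; may find nothing.
        return _BINDINGS.get(name)

    def get_binding(name: str) -> Binding:
        binding = query_first(name)
        if not binding:
            # Raising keeps the declared non-Optional return type truthful.
            raise ValueError("Dataset collection binding not found")
        return binding

    print(get_binding("default").name)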