[Chore/Refactor] Improve type annotations in models module (#25281)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
-LAN-
2025-09-08 09:42:27 +08:00
committed by GitHub
parent e1f871fefe
commit 9b8a03b53b
23 changed files with 332 additions and 251 deletions

View File

@@ -286,7 +286,7 @@ class DatasetProcessRule(Base):
"segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
}
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"dataset_id": self.dataset_id,
@@ -295,7 +295,7 @@ class DatasetProcessRule(Base):
}
@property
def rules_dict(self):
def rules_dict(self) -> dict[str, Any] | None:
try:
return json.loads(self.rules) if self.rules else None
except JSONDecodeError:
@@ -392,10 +392,10 @@ class Document(Base):
return status
@property
def data_source_info_dict(self):
def data_source_info_dict(self) -> dict[str, Any] | None:
if self.data_source_info:
try:
data_source_info_dict = json.loads(self.data_source_info)
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
except JSONDecodeError:
data_source_info_dict = {}
@@ -403,10 +403,10 @@ class Document(Base):
return None
@property
def data_source_detail_dict(self):
def data_source_detail_dict(self) -> dict[str, Any]:
if self.data_source_info:
if self.data_source_type == "upload_file":
data_source_info_dict = json.loads(self.data_source_info)
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.id == data_source_info_dict["upload_file_id"])
@@ -425,7 +425,8 @@ class Document(Base):
}
}
elif self.data_source_type in {"notion_import", "website_crawl"}:
return json.loads(self.data_source_info)
result: dict[str, Any] = json.loads(self.data_source_info)
return result
return {}
@property
@@ -471,7 +472,7 @@ class Document(Base):
return self.updated_at
@property
def doc_metadata_details(self):
def doc_metadata_details(self) -> list[dict[str, Any]] | None:
if self.doc_metadata:
document_metadatas = (
db.session.query(DatasetMetadata)
@@ -481,9 +482,9 @@ class Document(Base):
)
.all()
)
metadata_list = []
metadata_list: list[dict[str, Any]] = []
for metadata in document_metadatas:
metadata_dict = {
metadata_dict: dict[str, Any] = {
"id": metadata.id,
"name": metadata.name,
"type": metadata.type,
@@ -497,13 +498,13 @@ class Document(Base):
return None
@property
def process_rule_dict(self):
if self.dataset_process_rule_id:
def process_rule_dict(self) -> dict[str, Any] | None:
if self.dataset_process_rule_id and self.dataset_process_rule:
return self.dataset_process_rule.to_dict()
return None
def get_built_in_fields(self):
built_in_fields = []
def get_built_in_fields(self) -> list[dict[str, Any]]:
built_in_fields: list[dict[str, Any]] = []
built_in_fields.append(
{
"id": "built-in",
@@ -546,7 +547,7 @@ class Document(Base):
)
return built_in_fields
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"tenant_id": self.tenant_id,
@@ -592,13 +593,13 @@ class Document(Base):
"data_source_info_dict": self.data_source_info_dict,
"average_segment_length": self.average_segment_length,
"dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
"dataset": self.dataset.to_dict() if self.dataset else None,
"dataset": None, # Dataset class doesn't have a to_dict method
"segment_count": self.segment_count,
"hit_count": self.hit_count,
}
@classmethod
def from_dict(cls, data: dict):
def from_dict(cls, data: dict[str, Any]):
return cls(
id=data.get("id"),
tenant_id=data.get("tenant_id"),
@@ -711,46 +712,48 @@ class DocumentSegment(Base):
)
@property
def child_chunks(self):
process_rule = self.document.dataset_process_rule
if process_rule.mode == "hierarchical":
rules = Rule(**process_rule.rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
else:
return []
else:
def child_chunks(self) -> list[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
return []
def get_child_chunks(self):
process_rule = self.document.dataset_process_rule
if process_rule.mode == "hierarchical":
rules = Rule(**process_rule.rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
else:
return []
else:
def get_child_chunks(self) -> list[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
return []
@property
def sign_content(self):
def sign_content(self) -> str:
return self.get_sign_content()
def get_sign_content(self):
signed_urls = []
def get_sign_content(self) -> str:
signed_urls: list[tuple[int, int, str]] = []
text = self.content
# For data before v0.10.0
@@ -890,17 +893,22 @@ class DatasetKeywordTable(Base):
)
@property
def keyword_table_dict(self):
def keyword_table_dict(self) -> dict[str, set[Any]] | None:
class SetDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
super().__init__(object_hook=self.object_hook, *args, **kwargs)
def __init__(self, *args: Any, **kwargs: Any) -> None:
def object_hook(dct: Any) -> Any:
if isinstance(dct, dict):
result: dict[str, Any] = {}
items = cast(dict[str, Any], dct).items()
for keyword, node_idxs in items:
if isinstance(node_idxs, list):
result[keyword] = set(cast(list[Any], node_idxs))
else:
result[keyword] = node_idxs
return result
return dct
def object_hook(self, dct):
if isinstance(dct, dict):
for keyword, node_idxs in dct.items():
if isinstance(node_idxs, list):
dct[keyword] = set(node_idxs)
return dct
super().__init__(object_hook=object_hook, *args, **kwargs)
# get dataset
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
@@ -1026,7 +1034,7 @@ class ExternalKnowledgeApis(Base):
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"tenant_id": self.tenant_id,
@@ -1039,14 +1047,14 @@ class ExternalKnowledgeApis(Base):
}
@property
def settings_dict(self):
def settings_dict(self) -> dict[str, Any] | None:
try:
return json.loads(self.settings) if self.settings else None
except JSONDecodeError:
return None
@property
def dataset_bindings(self):
def dataset_bindings(self) -> list[dict[str, Any]]:
external_knowledge_bindings = (
db.session.query(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
@@ -1054,7 +1062,7 @@ class ExternalKnowledgeApis(Base):
)
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
dataset_bindings = []
dataset_bindings: list[dict[str, Any]] = []
for dataset in datasets:
dataset_bindings.append({"id": dataset.id, "name": dataset.name})