[Chore/Refactor] Improve type annotations in models module (#25281)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
@@ -286,7 +286,7 @@ class DatasetProcessRule(Base):
         "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
     }
 
-    def to_dict(self):
+    def to_dict(self) -> dict[str, Any]:
         return {
             "id": self.id,
             "dataset_id": self.dataset_id,
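Every hunk in this diff follows the same recipe: bare signatures gain precise return types. A minimal standalone sketch of the annotated shape (the class and fields below are illustrative stand-ins, not the real model; it assumes Python 3.10+ and that `Any` is imported at module top):

from typing import Any


class ProcessRuleSketch:
    # Illustrative stand-in for DatasetProcessRule, trimmed to two fields.
    def __init__(self, id: str, dataset_id: str) -> None:
        self.id = id
        self.dataset_id = dataset_id

    def to_dict(self) -> dict[str, Any]:
        # The real model mixes strings, ints and datetimes in one dict,
        # which makes dict[str, Any] the honest return type.
        return {"id": self.id, "dataset_id": self.dataset_id}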
@@ -295,7 +295,7 @@ class DatasetProcessRule(Base):
         }
 
     @property
-    def rules_dict(self):
+    def rules_dict(self) -> dict[str, Any] | None:
         try:
             return json.loads(self.rules) if self.rules else None
         except JSONDecodeError:
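`rules_dict` can legitimately produce `None`, both when the `rules` column is empty and when its JSON is malformed, so `dict[str, Any] | None` is the honest annotation. The same contract, sketched without the ORM (function name is illustrative):

import json
from json import JSONDecodeError
from typing import Any


def parse_rules(raw: str | None) -> dict[str, Any] | None:
    # None for an empty column, None again for malformed JSON: one explicit contract.
    try:
        return json.loads(raw) if raw else None
    except JSONDecodeError:
        return None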
@@ -392,10 +392,10 @@ class Document(Base):
         return status
 
     @property
-    def data_source_info_dict(self):
+    def data_source_info_dict(self) -> dict[str, Any] | None:
         if self.data_source_info:
             try:
-                data_source_info_dict = json.loads(self.data_source_info)
+                data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
             except JSONDecodeError:
                 data_source_info_dict = {}
 
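`json.loads` is typed as returning `Any`, so the interesting change here is annotating the variable at its binding site; that is the point where a checker can start enforcing the expected shape. A small illustration (the payload is invented):

import json
from typing import Any

raw = '{"upload_file_id": "abc-123"}'  # invented payload

# json.loads returns Any; the annotation pins the expected shape at the binding site.
info: dict[str, Any] = json.loads(raw)
print(info["upload_file_id"])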
@@ -403,10 +403,10 @@ class Document(Base):
         return None
 
     @property
-    def data_source_detail_dict(self):
+    def data_source_detail_dict(self) -> dict[str, Any]:
         if self.data_source_info:
             if self.data_source_type == "upload_file":
-                data_source_info_dict = json.loads(self.data_source_info)
+                data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
                 file_detail = (
                     db.session.query(UploadFile)
                     .where(UploadFile.id == data_source_info_dict["upload_file_id"])
@@ -425,7 +425,8 @@ class Document(Base):
                 }
             }
         elif self.data_source_type in {"notion_import", "website_crawl"}:
-            return json.loads(self.data_source_info)
+            result: dict[str, Any] = json.loads(self.data_source_info)
+            return result
         return {}
 
     @property
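Returning `json.loads(...)` straight out of a function annotated `-> dict[str, Any]` would silently pass an unchecked `Any` through; the new `result: dict[str, Any] = ...` binding gives the checker an anchor before the return. In isolation (illustrative name):

import json
from typing import Any


def detail_from_json(raw: str) -> dict[str, Any]:
    # Pin json.loads' Any result to an annotated local before returning it.
    result: dict[str, Any] = json.loads(raw)
    return result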
@@ -471,7 +472,7 @@ class Document(Base):
         return self.updated_at
 
     @property
-    def doc_metadata_details(self):
+    def doc_metadata_details(self) -> list[dict[str, Any]] | None:
         if self.doc_metadata:
             document_metadatas = (
                 db.session.query(DatasetMetadata)
@@ -481,9 +482,9 @@ class Document(Base):
                 )
                 .all()
             )
-            metadata_list = []
+            metadata_list: list[dict[str, Any]] = []
             for metadata in document_metadatas:
-                metadata_dict = {
+                metadata_dict: dict[str, Any] = {
                     "id": metadata.id,
                     "name": metadata.name,
                     "type": metadata.type,
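An empty literal like `metadata_list = []` gives type inference nothing to work with, hence the annotations at creation time. A toy version of the same accumulator pattern (the data is invented):

from typing import Any

rows = [("m1", "title", "string"), ("m2", "language", "string")]  # invented data

# An untyped [] leaves the element type undetermined; annotate at creation.
metadata_list: list[dict[str, Any]] = []
for id_, name, type_ in rows:
    metadata_dict: dict[str, Any] = {"id": id_, "name": name, "type": type_}
    metadata_list.append(metadata_dict)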
@@ -497,13 +498,13 @@ class Document(Base):
         return None
 
     @property
-    def process_rule_dict(self):
-        if self.dataset_process_rule_id:
+    def process_rule_dict(self) -> dict[str, Any] | None:
+        if self.dataset_process_rule_id and self.dataset_process_rule:
             return self.dataset_process_rule.to_dict()
         return None
 
-    def get_built_in_fields(self):
-        built_in_fields = []
+    def get_built_in_fields(self) -> list[dict[str, Any]]:
+        built_in_fields: list[dict[str, Any]] = []
         built_in_fields.append(
             {
                 "id": "built-in",
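The `process_rule_dict` change is behavioral, not just cosmetic: an ORM relationship can be `None` even when its foreign-key column is set (for instance, when the referenced row is gone), so the added `and self.dataset_process_rule` guard prevents an `AttributeError` on `.to_dict()`. Sketched without SQLAlchemy (class names are stand-ins):

from typing import Any


class RuleSketch:
    def to_dict(self) -> dict[str, Any]:
        return {"mode": "automatic"}


class DocumentRowSketch:
    # dataset_process_rule mimics a nullable ORM relationship: the FK column can
    # be set while the related row is missing, leaving the attribute None.
    def __init__(self, rule_id: str | None, rule: RuleSketch | None) -> None:
        self.dataset_process_rule_id = rule_id
        self.dataset_process_rule = rule

    @property
    def process_rule_dict(self) -> dict[str, Any] | None:
        # Guard both the FK and the relationship, as the diff does.
        if self.dataset_process_rule_id and self.dataset_process_rule:
            return self.dataset_process_rule.to_dict()
        return None


assert DocumentRowSketch("rule-1", None).process_rule_dict is None  # no AttributeError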
@@ -546,7 +547,7 @@ class Document(Base):
         )
         return built_in_fields
 
-    def to_dict(self):
+    def to_dict(self) -> dict[str, Any]:
         return {
             "id": self.id,
             "tenant_id": self.tenant_id,
@@ -592,13 +593,13 @@ class Document(Base):
             "data_source_info_dict": self.data_source_info_dict,
             "average_segment_length": self.average_segment_length,
             "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
-            "dataset": self.dataset.to_dict() if self.dataset else None,
+            "dataset": None,  # Dataset class doesn't have a to_dict method
             "segment_count": self.segment_count,
             "hit_count": self.hit_count,
         }
 
     @classmethod
-    def from_dict(cls, data: dict):
+    def from_dict(cls, data: dict[str, Any]):
         return cls(
             id=data.get("id"),
             tenant_id=data.get("tenant_id"),
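`from_dict` now takes `dict[str, Any]` instead of bare `dict`, which parameterizes both keys and values. A self-contained usage sketch (the class and fields are illustrative, and the return annotation is added here for completeness; the diff leaves it off):

from typing import Any


class DocumentSketch:
    # Illustrative stand-in; the real model has many more fields.
    def __init__(self, id: str | None = None, tenant_id: str | None = None) -> None:
        self.id = id
        self.tenant_id = tenant_id

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "DocumentSketch":
        # dict.get returns Any here, matching the loosely typed payload.
        return cls(id=data.get("id"), tenant_id=data.get("tenant_id"))


doc = DocumentSketch.from_dict({"id": "d1", "tenant_id": "t1"})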
@@ -711,46 +712,48 @@ class DocumentSegment(Base):
         )
 
     @property
-    def child_chunks(self):
-        process_rule = self.document.dataset_process_rule
-        if process_rule.mode == "hierarchical":
-            rules = Rule(**process_rule.rules_dict)
-            if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
-                child_chunks = (
-                    db.session.query(ChildChunk)
-                    .where(ChildChunk.segment_id == self.id)
-                    .order_by(ChildChunk.position.asc())
-                    .all()
-                )
-                return child_chunks or []
-            else:
-                return []
-        else:
-            return []
+    def child_chunks(self) -> list[Any]:
+        if not self.document:
+            return []
+        process_rule = self.document.dataset_process_rule
+        if process_rule and process_rule.mode == "hierarchical":
+            rules_dict = process_rule.rules_dict
+            if rules_dict:
+                rules = Rule(**rules_dict)
+                if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
+                    child_chunks = (
+                        db.session.query(ChildChunk)
+                        .where(ChildChunk.segment_id == self.id)
+                        .order_by(ChildChunk.position.asc())
+                        .all()
+                    )
+                    return child_chunks or []
+        return []
 
-    def get_child_chunks(self):
-        process_rule = self.document.dataset_process_rule
-        if process_rule.mode == "hierarchical":
-            rules = Rule(**process_rule.rules_dict)
-            if rules.parent_mode:
-                child_chunks = (
-                    db.session.query(ChildChunk)
-                    .where(ChildChunk.segment_id == self.id)
-                    .order_by(ChildChunk.position.asc())
-                    .all()
-                )
-                return child_chunks or []
-            else:
-                return []
-        else:
-            return []
+    def get_child_chunks(self) -> list[Any]:
+        if not self.document:
+            return []
+        process_rule = self.document.dataset_process_rule
+        if process_rule and process_rule.mode == "hierarchical":
+            rules_dict = process_rule.rules_dict
+            if rules_dict:
+                rules = Rule(**rules_dict)
+                if rules.parent_mode:
+                    child_chunks = (
+                        db.session.query(ChildChunk)
+                        .where(ChildChunk.segment_id == self.id)
+                        .order_by(ChildChunk.position.asc())
+                        .all()
+                    )
+                    return child_chunks or []
+        return []
 
     @property
-    def sign_content(self):
+    def sign_content(self) -> str:
         return self.get_sign_content()
 
-    def get_sign_content(self):
-        signed_urls = []
+    def get_sign_content(self) -> str:
+        signed_urls: list[tuple[int, int, str]] = []
         text = self.content
 
         # For data before v0.10.0
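Both chunk accessors get the same two fixes: the nested `else: return []` ladders flatten into early-return guard clauses, and `rules_dict` is checked before being splatted into `Rule(**rules_dict)`, since `Rule(**None)` would raise a `TypeError`. The control flow reduced to a skeleton (the function name and `document` argument are assumptions for illustration):

from typing import Any


def child_chunks_skeleton(document: Any) -> list[Any]:
    # Guard clauses replace the old else-ladder; every fall-through returns [].
    if not document:
        return []
    process_rule = document.dataset_process_rule
    if process_rule and process_rule.mode == "hierarchical":
        rules_dict = process_rule.rules_dict
        if rules_dict:  # rules_dict is Optional; Rule(**None) would raise TypeError
            ...  # query ChildChunk rows here and return them
    return []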
@@ -890,17 +893,22 @@ class DatasetKeywordTable(Base):
         )
 
     @property
-    def keyword_table_dict(self):
+    def keyword_table_dict(self) -> dict[str, set[Any]] | None:
         class SetDecoder(json.JSONDecoder):
-            def __init__(self, *args, **kwargs):
-                super().__init__(object_hook=self.object_hook, *args, **kwargs)
+            def __init__(self, *args: Any, **kwargs: Any) -> None:
+                def object_hook(dct: Any) -> Any:
+                    if isinstance(dct, dict):
+                        result: dict[str, Any] = {}
+                        items = cast(dict[str, Any], dct).items()
+                        for keyword, node_idxs in items:
+                            if isinstance(node_idxs, list):
+                                result[keyword] = set(cast(list[Any], node_idxs))
+                            else:
+                                result[keyword] = node_idxs
+                        return result
+                    return dct
 
-            def object_hook(self, dct):
-                if isinstance(dct, dict):
-                    for keyword, node_idxs in dct.items():
-                        if isinstance(node_idxs, list):
-                            dct[keyword] = set(node_idxs)
-                return dct
+                super().__init__(object_hook=object_hook, *args, **kwargs)
 
         # get dataset
         dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
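The `SetDecoder` rewrite moves `object_hook` from a method to a closure inside `__init__`, and builds a fresh `result` dict instead of mutating `dct` in place. A self-contained, runnable version of the same decoder, stdlib only (simplified slightly from the diff: the list/non-list branch is collapsed into one conditional expression):

import json
from typing import Any, cast


class SetDecoder(json.JSONDecoder):
    # JSON decoder that turns list values into sets, as keyword_table_dict expects.
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        def object_hook(dct: Any) -> Any:
            if isinstance(dct, dict):
                result: dict[str, Any] = {}
                for keyword, node_idxs in cast(dict[str, Any], dct).items():
                    result[keyword] = set(node_idxs) if isinstance(node_idxs, list) else node_idxs
                return result
            return dct

        super().__init__(object_hook=object_hook, *args, **kwargs)


table = json.loads('{"alpha": [1, 2, 2], "beta": [3]}', cls=SetDecoder)
print(table)  # {'alpha': {1, 2}, 'beta': {3}}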
@@ -1026,7 +1034,7 @@ class ExternalKnowledgeApis(Base):
     updated_by = mapped_column(StringUUID, nullable=True)
     updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
 
-    def to_dict(self):
+    def to_dict(self) -> dict[str, Any]:
         return {
             "id": self.id,
             "tenant_id": self.tenant_id,
@@ -1039,14 +1047,14 @@ class ExternalKnowledgeApis(Base):
         }
 
     @property
-    def settings_dict(self):
+    def settings_dict(self) -> dict[str, Any] | None:
         try:
             return json.loads(self.settings) if self.settings else None
         except JSONDecodeError:
             return None
 
     @property
-    def dataset_bindings(self):
+    def dataset_bindings(self) -> list[dict[str, Any]]:
         external_knowledge_bindings = (
             db.session.query(ExternalKnowledgeBindings)
             .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
@@ -1054,7 +1062,7 @@ class ExternalKnowledgeApis(Base):
         )
         dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
         datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
-        dataset_bindings = []
+        dataset_bindings: list[dict[str, Any]] = []
         for dataset in datasets:
             dataset_bindings.append({"id": dataset.id, "name": dataset.name})