[Chore/Refactor] Improve type annotations in models module (#25281)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
-LAN-
2025-09-08 09:42:27 +08:00
committed by GitHub
parent e1f871fefe
commit 9b8a03b53b
23 changed files with 332 additions and 251 deletions

View File

@@ -1,29 +1,34 @@
import enum
from typing import Generic, TypeVar
import uuid
from typing import Any, Generic, TypeVar
from sqlalchemy import CHAR, VARCHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.type_api import TypeEngine
class StringUUID(TypeDecorator):
class StringUUID(TypeDecorator[uuid.UUID | str | None]):
impl = CHAR
cache_ok = True
def process_bind_param(self, value, dialect):
def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
if value is None:
return value
elif dialect.name == "postgresql":
return str(value)
else:
return value.hex
if isinstance(value, uuid.UUID):
return value.hex
return value
def load_dialect_impl(self, dialect):
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
if dialect.name == "postgresql":
return dialect.type_descriptor(UUID())
else:
return dialect.type_descriptor(CHAR(36))
def process_result_value(self, value, dialect):
def process_result_value(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
if value is None:
return value
return str(value)
@@ -32,7 +37,7 @@ class StringUUID(TypeDecorator):
_E = TypeVar("_E", bound=enum.StrEnum)
class EnumText(TypeDecorator, Generic[_E]):
class EnumText(TypeDecorator[_E | None], Generic[_E]):
impl = VARCHAR
cache_ok = True
@@ -50,28 +55,25 @@ class EnumText(TypeDecorator, Generic[_E]):
# leave some rooms for future longer enum values.
self._length = max(max_enum_value_len, 20)
def process_bind_param(self, value: _E | str | None, dialect):
def process_bind_param(self, value: _E | str | None, dialect: Dialect) -> str | None:
if value is None:
return value
if isinstance(value, self._enum_class):
return value.value
elif isinstance(value, str):
self._enum_class(value)
return value
else:
raise TypeError(f"expected str or {self._enum_class}, got {type(value)}")
# Since _E is bound to StrEnum which inherits from str, at this point value must be str
self._enum_class(value)
return value
def load_dialect_impl(self, dialect):
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
return dialect.type_descriptor(VARCHAR(self._length))
def process_result_value(self, value, dialect) -> _E | None:
def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None:
if value is None:
return value
if not isinstance(value, str):
raise TypeError(f"expected str, got {type(value)}")
# Type annotation guarantees value is str at this point
return self._enum_class(value)
def compare_values(self, x, y):
def compare_values(self, x: _E | None, y: _E | None) -> bool:
if x is None or y is None:
return x is y
return x == y