feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -4,11 +4,11 @@ import uuid
from collections.abc import Mapping
from datetime import datetime
from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Literal, Optional
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin
from flask_login import UserMixin # type: ignore
from sqlalchemy import Float, func, text
from sqlalchemy.orm import Mapped, mapped_column
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
from .workflow import Workflow
class DifySetup(db.Model):
class DifySetup(db.Model): # type: ignore[name-defined]
__tablename__ = "dify_setups"
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
@@ -63,7 +63,7 @@ class IconType(Enum):
EMOJI = "emoji"
class App(db.Model):
class App(db.Model): # type: ignore[name-defined]
__tablename__ = "apps"
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id"))
@@ -86,7 +86,7 @@ class App(db.Model):
is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
tracing = db.Column(db.Text, nullable=True)
max_active_requests = db.Column(db.Integer, nullable=True)
max_active_requests: Mapped[Optional[int]] = mapped_column(nullable=True)
created_by = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True)
@@ -154,7 +154,7 @@ class App(db.Model):
if self.mode == AppMode.CHAT.value and self.is_agent:
return AppMode.AGENT_CHAT.value
return self.mode
return str(self.mode)
@property
def deleted_tools(self) -> list:
@@ -219,7 +219,7 @@ class App(db.Model):
return tags or []
class AppModelConfig(db.Model):
class AppModelConfig(db.Model): # type: ignore[name-defined]
__tablename__ = "app_model_configs"
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id"))
@@ -322,7 +322,7 @@ class AppModelConfig(db.Model):
return json.loads(self.external_data_tools) if self.external_data_tools else []
@property
def user_input_form_list(self) -> dict:
def user_input_form_list(self) -> list[dict]:
return json.loads(self.user_input_form) if self.user_input_form else []
@property
@@ -344,7 +344,7 @@ class AppModelConfig(db.Model):
@property
def dataset_configs_dict(self) -> dict:
if self.dataset_configs:
dataset_configs = json.loads(self.dataset_configs)
dataset_configs: dict = json.loads(self.dataset_configs)
if "retrieval_model" not in dataset_configs:
return {"retrieval_model": "single"}
else:
@@ -466,7 +466,7 @@ class AppModelConfig(db.Model):
return new_app_model_config
class RecommendedApp(db.Model):
class RecommendedApp(db.Model): # type: ignore[name-defined]
__tablename__ = "recommended_apps"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
@@ -494,7 +494,7 @@ class RecommendedApp(db.Model):
return app
class InstalledApp(db.Model):
class InstalledApp(db.Model): # type: ignore[name-defined]
__tablename__ = "installed_apps"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="installed_app_pkey"),
@@ -523,7 +523,7 @@ class InstalledApp(db.Model):
return tenant
class Conversation(db.Model):
class Conversation(db.Model): # type: ignore[name-defined]
__tablename__ = "conversations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="conversation_pkey"),
@@ -602,6 +602,8 @@ class Conversation(db.Model):
@property
def model_config(self):
model_config = {}
app_model_config: Optional[AppModelConfig] = None
if self.mode == AppMode.ADVANCED_CHAT.value:
if self.override_model_configs:
override_model_configs = json.loads(self.override_model_configs)
@@ -613,6 +615,7 @@ class Conversation(db.Model):
if "model" in override_model_configs:
app_model_config = AppModelConfig()
app_model_config = app_model_config.from_model_config_dict(override_model_configs)
assert app_model_config is not None, "app model config not found"
model_config = app_model_config.to_dict()
else:
model_config["configs"] = override_model_configs
@@ -755,7 +758,7 @@ class Conversation(db.Model):
return self.override_model_configs is not None
class Message(db.Model):
class Message(db.Model): # type: ignore[name-defined]
__tablename__ = "messages"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_pkey"),
@@ -995,7 +998,7 @@ class Message(db.Model):
if not current_app:
raise ValueError(f"App {self.app_id} not found")
files: list[File] = []
files = []
for message_file in message_files:
if message_file.transfer_method == "local_file":
if message_file.upload_file_id is None:
@@ -1102,7 +1105,7 @@ class Message(db.Model):
)
class MessageFeedback(db.Model):
class MessageFeedback(db.Model): # type: ignore[name-defined]
__tablename__ = "message_feedbacks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
@@ -1129,7 +1132,7 @@ class MessageFeedback(db.Model):
return account
class MessageFile(db.Model):
class MessageFile(db.Model): # type: ignore[name-defined]
__tablename__ = "message_files"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_file_pkey"),
@@ -1170,7 +1173,7 @@ class MessageFile(db.Model):
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class MessageAnnotation(db.Model):
class MessageAnnotation(db.Model): # type: ignore[name-defined]
__tablename__ = "message_annotations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
@@ -1201,7 +1204,7 @@ class MessageAnnotation(db.Model):
return account
class AppAnnotationHitHistory(db.Model):
class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
__tablename__ = "app_annotation_hit_histories"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
@@ -1239,7 +1242,7 @@ class AppAnnotationHitHistory(db.Model):
return account
class AppAnnotationSetting(db.Model):
class AppAnnotationSetting(db.Model): # type: ignore[name-defined]
__tablename__ = "app_annotation_settings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
@@ -1287,7 +1290,7 @@ class AppAnnotationSetting(db.Model):
return collection_binding_detail
class OperationLog(db.Model):
class OperationLog(db.Model): # type: ignore[name-defined]
__tablename__ = "operation_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="operation_log_pkey"),
@@ -1304,7 +1307,7 @@ class OperationLog(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class EndUser(UserMixin, db.Model):
class EndUser(UserMixin, db.Model): # type: ignore[name-defined]
__tablename__ = "end_users"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="end_user_pkey"),
@@ -1324,7 +1327,7 @@ class EndUser(UserMixin, db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class Site(db.Model):
class Site(db.Model): # type: ignore[name-defined]
__tablename__ = "sites"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="site_pkey"),
@@ -1381,7 +1384,7 @@ class Site(db.Model):
return dify_config.APP_WEB_URL or request.url_root.rstrip("/")
class ApiToken(db.Model):
class ApiToken(db.Model): # type: ignore[name-defined]
__tablename__ = "api_tokens"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="api_token_pkey"),
@@ -1408,7 +1411,7 @@ class ApiToken(db.Model):
return result
class UploadFile(db.Model):
class UploadFile(db.Model): # type: ignore[name-defined]
__tablename__ = "upload_files"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="upload_file_pkey"),
@@ -1470,7 +1473,7 @@ class UploadFile(db.Model):
self.source_url = source_url
class ApiRequest(db.Model):
class ApiRequest(db.Model): # type: ignore[name-defined]
__tablename__ = "api_requests"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="api_request_pkey"),
@@ -1487,7 +1490,7 @@ class ApiRequest(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class MessageChain(db.Model):
class MessageChain(db.Model): # type: ignore[name-defined]
__tablename__ = "message_chains"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_chain_pkey"),
@@ -1502,7 +1505,7 @@ class MessageChain(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
class MessageAgentThought(db.Model):
class MessageAgentThought(db.Model): # type: ignore[name-defined]
__tablename__ = "message_agent_thoughts"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
@@ -1542,7 +1545,7 @@ class MessageAgentThought(db.Model):
@property
def files(self) -> list:
if self.message_files:
return json.loads(self.message_files)
return cast(list[Any], json.loads(self.message_files))
else:
return []
@@ -1554,7 +1557,7 @@ class MessageAgentThought(db.Model):
def tool_labels(self) -> dict:
try:
if self.tool_labels_str:
return json.loads(self.tool_labels_str)
return cast(dict, json.loads(self.tool_labels_str))
else:
return {}
except Exception as e:
@@ -1564,7 +1567,7 @@ class MessageAgentThought(db.Model):
def tool_meta(self) -> dict:
try:
if self.tool_meta_str:
return json.loads(self.tool_meta_str)
return cast(dict, json.loads(self.tool_meta_str))
else:
return {}
except Exception as e:
@@ -1612,9 +1615,11 @@ class MessageAgentThought(db.Model):
except Exception as e:
if self.observation:
return dict.fromkeys(tools, self.observation)
else:
return {}
class DatasetRetrieverResource(db.Model):
class DatasetRetrieverResource(db.Model): # type: ignore[name-defined]
__tablename__ = "dataset_retriever_resources"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"),
@@ -1641,7 +1646,7 @@ class DatasetRetrieverResource(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
class Tag(db.Model):
class Tag(db.Model): # type: ignore[name-defined]
__tablename__ = "tags"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tag_pkey"),
@@ -1659,7 +1664,7 @@ class Tag(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TagBinding(db.Model):
class TagBinding(db.Model): # type: ignore[name-defined]
__tablename__ = "tag_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tag_binding_pkey"),
@@ -1675,7 +1680,7 @@ class TagBinding(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TraceAppConfig(db.Model):
class TraceAppConfig(db.Model): # type: ignore[name-defined]
__tablename__ = "trace_app_config"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"),