feat: mypy for all type check (#10921)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import enum
|
||||
import json
|
||||
|
||||
from flask_login import UserMixin
|
||||
from flask_login import UserMixin # type: ignore
|
||||
from sqlalchemy import func
|
||||
|
||||
from .engine import db
|
||||
@@ -16,7 +16,7 @@ class AccountStatus(enum.StrEnum):
|
||||
CLOSED = "closed"
|
||||
|
||||
|
||||
class Account(UserMixin, db.Model):
|
||||
class Account(UserMixin, db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "accounts"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email"))
|
||||
|
||||
@@ -43,7 +43,8 @@ class Account(UserMixin, db.Model):
|
||||
|
||||
@property
|
||||
def current_tenant(self):
|
||||
return self._current_tenant
|
||||
# FIXME: fix the type error later, because the type is important maybe cause some bugs
|
||||
return self._current_tenant # type: ignore
|
||||
|
||||
@current_tenant.setter
|
||||
def current_tenant(self, value: "Tenant"):
|
||||
@@ -52,7 +53,8 @@ class Account(UserMixin, db.Model):
|
||||
if ta:
|
||||
tenant.current_role = ta.role
|
||||
else:
|
||||
tenant = None
|
||||
# FIXME: fix the type error later, because the type is important maybe cause some bugs
|
||||
tenant = None # type: ignore
|
||||
self._current_tenant = tenant
|
||||
|
||||
@property
|
||||
@@ -89,7 +91,7 @@ class Account(UserMixin, db.Model):
|
||||
return AccountStatus(status_str)
|
||||
|
||||
@classmethod
|
||||
def get_by_openid(cls, provider: str, open_id: str) -> db.Model:
|
||||
def get_by_openid(cls, provider: str, open_id: str):
|
||||
account_integrate = (
|
||||
db.session.query(AccountIntegrate)
|
||||
.filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
|
||||
@@ -134,7 +136,7 @@ class TenantAccountRole(enum.StrEnum):
|
||||
|
||||
@staticmethod
|
||||
def is_valid_role(role: str) -> bool:
|
||||
return role and role in {
|
||||
return role in {
|
||||
TenantAccountRole.OWNER,
|
||||
TenantAccountRole.ADMIN,
|
||||
TenantAccountRole.EDITOR,
|
||||
@@ -144,15 +146,15 @@ class TenantAccountRole(enum.StrEnum):
|
||||
|
||||
@staticmethod
|
||||
def is_privileged_role(role: str) -> bool:
|
||||
return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
|
||||
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
|
||||
|
||||
@staticmethod
|
||||
def is_admin_role(role: str) -> bool:
|
||||
return role and role == TenantAccountRole.ADMIN
|
||||
return role == TenantAccountRole.ADMIN
|
||||
|
||||
@staticmethod
|
||||
def is_non_owner_role(role: str) -> bool:
|
||||
return role and role in {
|
||||
return role in {
|
||||
TenantAccountRole.ADMIN,
|
||||
TenantAccountRole.EDITOR,
|
||||
TenantAccountRole.NORMAL,
|
||||
@@ -161,11 +163,11 @@ class TenantAccountRole(enum.StrEnum):
|
||||
|
||||
@staticmethod
|
||||
def is_editing_role(role: str) -> bool:
|
||||
return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
|
||||
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
|
||||
|
||||
@staticmethod
|
||||
def is_dataset_edit_role(role: str) -> bool:
|
||||
return role and role in {
|
||||
return role in {
|
||||
TenantAccountRole.OWNER,
|
||||
TenantAccountRole.ADMIN,
|
||||
TenantAccountRole.EDITOR,
|
||||
@@ -173,7 +175,7 @@ class TenantAccountRole(enum.StrEnum):
|
||||
}
|
||||
|
||||
|
||||
class Tenant(db.Model):
|
||||
class Tenant(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "tenants"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
|
||||
|
||||
@@ -209,7 +211,7 @@ class TenantAccountJoinRole(enum.Enum):
|
||||
DATASET_OPERATOR = "dataset_operator"
|
||||
|
||||
|
||||
class TenantAccountJoin(db.Model):
|
||||
class TenantAccountJoin(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "tenant_account_joins"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
|
||||
@@ -228,7 +230,7 @@ class TenantAccountJoin(db.Model):
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class AccountIntegrate(db.Model):
|
||||
class AccountIntegrate(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "account_integrates"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
|
||||
@@ -245,7 +247,7 @@ class AccountIntegrate(db.Model):
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class InvitationCode(db.Model):
|
||||
class InvitationCode(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "invitation_codes"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
|
||||
|
@@ -13,7 +13,7 @@ class APIBasedExtensionPoint(enum.Enum):
|
||||
APP_MODERATION_OUTPUT = "app.moderation.output"
|
||||
|
||||
|
||||
class APIBasedExtension(db.Model):
|
||||
class APIBasedExtension(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "api_based_extensions"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),
|
||||
|
@@ -9,6 +9,7 @@ import pickle
|
||||
import re
|
||||
import time
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
@@ -29,7 +30,7 @@ class DatasetPermissionEnum(enum.StrEnum):
|
||||
PARTIAL_TEAM = "partial_members"
|
||||
|
||||
|
||||
class Dataset(db.Model):
|
||||
class Dataset(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "datasets"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_pkey"),
|
||||
@@ -200,7 +201,7 @@ class Dataset(db.Model):
|
||||
return f"Vector_index_{normalized_dataset_id}_Node"
|
||||
|
||||
|
||||
class DatasetProcessRule(db.Model):
|
||||
class DatasetProcessRule(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "dataset_process_rules"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
|
||||
@@ -216,7 +217,7 @@ class DatasetProcessRule(db.Model):
|
||||
|
||||
MODES = ["automatic", "custom"]
|
||||
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
|
||||
AUTOMATIC_RULES = {
|
||||
AUTOMATIC_RULES: dict[str, Any] = {
|
||||
"pre_processing_rules": [
|
||||
{"id": "remove_extra_spaces", "enabled": True},
|
||||
{"id": "remove_urls_emails", "enabled": False},
|
||||
@@ -242,7 +243,7 @@ class DatasetProcessRule(db.Model):
|
||||
return None
|
||||
|
||||
|
||||
class Document(db.Model):
|
||||
class Document(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "documents"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="document_pkey"),
|
||||
@@ -492,7 +493,7 @@ class Document(db.Model):
|
||||
)
|
||||
|
||||
|
||||
class DocumentSegment(db.Model):
|
||||
class DocumentSegment(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "document_segments"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
|
||||
@@ -604,7 +605,7 @@ class DocumentSegment(db.Model):
|
||||
return text
|
||||
|
||||
|
||||
class AppDatasetJoin(db.Model):
|
||||
class AppDatasetJoin(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "app_dataset_joins"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
|
||||
@@ -621,7 +622,7 @@ class AppDatasetJoin(db.Model):
|
||||
return db.session.get(App, self.app_id)
|
||||
|
||||
|
||||
class DatasetQuery(db.Model):
|
||||
class DatasetQuery(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "dataset_queries"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
|
||||
@@ -638,7 +639,7 @@ class DatasetQuery(db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
|
||||
class DatasetKeywordTable(db.Model):
|
||||
class DatasetKeywordTable(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "dataset_keyword_tables"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
|
||||
@@ -683,7 +684,7 @@ class DatasetKeywordTable(db.Model):
|
||||
return None
|
||||
|
||||
|
||||
class Embedding(db.Model):
|
||||
class Embedding(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "embeddings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="embedding_pkey"),
|
||||
@@ -704,10 +705,10 @@ class Embedding(db.Model):
|
||||
self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def get_embedding(self) -> list[float]:
|
||||
return pickle.loads(self.embedding)
|
||||
return cast(list[float], pickle.loads(self.embedding))
|
||||
|
||||
|
||||
class DatasetCollectionBinding(db.Model):
|
||||
class DatasetCollectionBinding(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "dataset_collection_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
|
||||
@@ -722,7 +723,7 @@ class DatasetCollectionBinding(db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TidbAuthBinding(db.Model):
|
||||
class TidbAuthBinding(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "tidb_auth_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
|
||||
@@ -742,7 +743,7 @@ class TidbAuthBinding(db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class Whitelist(db.Model):
|
||||
class Whitelist(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "whitelists"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
|
||||
@@ -754,7 +755,7 @@ class Whitelist(db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class DatasetPermission(db.Model):
|
||||
class DatasetPermission(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "dataset_permissions"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
|
||||
@@ -771,7 +772,7 @@ class DatasetPermission(db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ExternalKnowledgeApis(db.Model):
|
||||
class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "external_knowledge_apis"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
|
||||
@@ -824,7 +825,7 @@ class ExternalKnowledgeApis(db.Model):
|
||||
return dataset_bindings
|
||||
|
||||
|
||||
class ExternalKnowledgeBindings(db.Model):
|
||||
class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "external_knowledge_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
|
||||
|
@@ -4,11 +4,11 @@ import uuid
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import UserMixin
|
||||
from flask_login import UserMixin # type: ignore
|
||||
from sqlalchemy import Float, func, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
from .workflow import Workflow
|
||||
|
||||
|
||||
class DifySetup(db.Model):
|
||||
class DifySetup(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "dify_setups"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
|
||||
|
||||
@@ -63,7 +63,7 @@ class IconType(Enum):
|
||||
EMOJI = "emoji"
|
||||
|
||||
|
||||
class App(db.Model):
|
||||
class App(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "apps"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id"))
|
||||
|
||||
@@ -86,7 +86,7 @@ class App(db.Model):
|
||||
is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
tracing = db.Column(db.Text, nullable=True)
|
||||
max_active_requests = db.Column(db.Integer, nullable=True)
|
||||
max_active_requests: Mapped[Optional[int]] = mapped_column(nullable=True)
|
||||
created_by = db.Column(StringUUID, nullable=True)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = db.Column(StringUUID, nullable=True)
|
||||
@@ -154,7 +154,7 @@ class App(db.Model):
|
||||
if self.mode == AppMode.CHAT.value and self.is_agent:
|
||||
return AppMode.AGENT_CHAT.value
|
||||
|
||||
return self.mode
|
||||
return str(self.mode)
|
||||
|
||||
@property
|
||||
def deleted_tools(self) -> list:
|
||||
@@ -219,7 +219,7 @@ class App(db.Model):
|
||||
return tags or []
|
||||
|
||||
|
||||
class AppModelConfig(db.Model):
|
||||
class AppModelConfig(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "app_model_configs"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id"))
|
||||
|
||||
@@ -322,7 +322,7 @@ class AppModelConfig(db.Model):
|
||||
return json.loads(self.external_data_tools) if self.external_data_tools else []
|
||||
|
||||
@property
|
||||
def user_input_form_list(self) -> dict:
|
||||
def user_input_form_list(self) -> list[dict]:
|
||||
return json.loads(self.user_input_form) if self.user_input_form else []
|
||||
|
||||
@property
|
||||
@@ -344,7 +344,7 @@ class AppModelConfig(db.Model):
|
||||
@property
|
||||
def dataset_configs_dict(self) -> dict:
|
||||
if self.dataset_configs:
|
||||
dataset_configs = json.loads(self.dataset_configs)
|
||||
dataset_configs: dict = json.loads(self.dataset_configs)
|
||||
if "retrieval_model" not in dataset_configs:
|
||||
return {"retrieval_model": "single"}
|
||||
else:
|
||||
@@ -466,7 +466,7 @@ class AppModelConfig(db.Model):
|
||||
return new_app_model_config
|
||||
|
||||
|
||||
class RecommendedApp(db.Model):
|
||||
class RecommendedApp(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "recommended_apps"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
|
||||
@@ -494,7 +494,7 @@ class RecommendedApp(db.Model):
|
||||
return app
|
||||
|
||||
|
||||
class InstalledApp(db.Model):
|
||||
class InstalledApp(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "installed_apps"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="installed_app_pkey"),
|
||||
@@ -523,7 +523,7 @@ class InstalledApp(db.Model):
|
||||
return tenant
|
||||
|
||||
|
||||
class Conversation(db.Model):
|
||||
class Conversation(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "conversations"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="conversation_pkey"),
|
||||
@@ -602,6 +602,8 @@ class Conversation(db.Model):
|
||||
@property
|
||||
def model_config(self):
|
||||
model_config = {}
|
||||
app_model_config: Optional[AppModelConfig] = None
|
||||
|
||||
if self.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if self.override_model_configs:
|
||||
override_model_configs = json.loads(self.override_model_configs)
|
||||
@@ -613,6 +615,7 @@ class Conversation(db.Model):
|
||||
if "model" in override_model_configs:
|
||||
app_model_config = AppModelConfig()
|
||||
app_model_config = app_model_config.from_model_config_dict(override_model_configs)
|
||||
assert app_model_config is not None, "app model config not found"
|
||||
model_config = app_model_config.to_dict()
|
||||
else:
|
||||
model_config["configs"] = override_model_configs
|
||||
@@ -755,7 +758,7 @@ class Conversation(db.Model):
|
||||
return self.override_model_configs is not None
|
||||
|
||||
|
||||
class Message(db.Model):
|
||||
class Message(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "messages"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_pkey"),
|
||||
@@ -995,7 +998,7 @@ class Message(db.Model):
|
||||
if not current_app:
|
||||
raise ValueError(f"App {self.app_id} not found")
|
||||
|
||||
files: list[File] = []
|
||||
files = []
|
||||
for message_file in message_files:
|
||||
if message_file.transfer_method == "local_file":
|
||||
if message_file.upload_file_id is None:
|
||||
@@ -1102,7 +1105,7 @@ class Message(db.Model):
|
||||
)
|
||||
|
||||
|
||||
class MessageFeedback(db.Model):
|
||||
class MessageFeedback(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "message_feedbacks"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
|
||||
@@ -1129,7 +1132,7 @@ class MessageFeedback(db.Model):
|
||||
return account
|
||||
|
||||
|
||||
class MessageFile(db.Model):
|
||||
class MessageFile(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "message_files"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_file_pkey"),
|
||||
@@ -1170,7 +1173,7 @@ class MessageFile(db.Model):
|
||||
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class MessageAnnotation(db.Model):
|
||||
class MessageAnnotation(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "message_annotations"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
|
||||
@@ -1201,7 +1204,7 @@ class MessageAnnotation(db.Model):
|
||||
return account
|
||||
|
||||
|
||||
class AppAnnotationHitHistory(db.Model):
|
||||
class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "app_annotation_hit_histories"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
|
||||
@@ -1239,7 +1242,7 @@ class AppAnnotationHitHistory(db.Model):
|
||||
return account
|
||||
|
||||
|
||||
class AppAnnotationSetting(db.Model):
|
||||
class AppAnnotationSetting(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "app_annotation_settings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
|
||||
@@ -1287,7 +1290,7 @@ class AppAnnotationSetting(db.Model):
|
||||
return collection_binding_detail
|
||||
|
||||
|
||||
class OperationLog(db.Model):
|
||||
class OperationLog(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "operation_logs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="operation_log_pkey"),
|
||||
@@ -1304,7 +1307,7 @@ class OperationLog(db.Model):
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class EndUser(UserMixin, db.Model):
|
||||
class EndUser(UserMixin, db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "end_users"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="end_user_pkey"),
|
||||
@@ -1324,7 +1327,7 @@ class EndUser(UserMixin, db.Model):
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class Site(db.Model):
|
||||
class Site(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "sites"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="site_pkey"),
|
||||
@@ -1381,7 +1384,7 @@ class Site(db.Model):
|
||||
return dify_config.APP_WEB_URL or request.url_root.rstrip("/")
|
||||
|
||||
|
||||
class ApiToken(db.Model):
|
||||
class ApiToken(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "api_tokens"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="api_token_pkey"),
|
||||
@@ -1408,7 +1411,7 @@ class ApiToken(db.Model):
|
||||
return result
|
||||
|
||||
|
||||
class UploadFile(db.Model):
|
||||
class UploadFile(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "upload_files"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="upload_file_pkey"),
|
||||
@@ -1470,7 +1473,7 @@ class UploadFile(db.Model):
|
||||
self.source_url = source_url
|
||||
|
||||
|
||||
class ApiRequest(db.Model):
|
||||
class ApiRequest(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "api_requests"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="api_request_pkey"),
|
||||
@@ -1487,7 +1490,7 @@ class ApiRequest(db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class MessageChain(db.Model):
|
||||
class MessageChain(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "message_chains"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_chain_pkey"),
|
||||
@@ -1502,7 +1505,7 @@ class MessageChain(db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
|
||||
class MessageAgentThought(db.Model):
|
||||
class MessageAgentThought(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "message_agent_thoughts"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
|
||||
@@ -1542,7 +1545,7 @@ class MessageAgentThought(db.Model):
|
||||
@property
|
||||
def files(self) -> list:
|
||||
if self.message_files:
|
||||
return json.loads(self.message_files)
|
||||
return cast(list[Any], json.loads(self.message_files))
|
||||
else:
|
||||
return []
|
||||
|
||||
@@ -1554,7 +1557,7 @@ class MessageAgentThought(db.Model):
|
||||
def tool_labels(self) -> dict:
|
||||
try:
|
||||
if self.tool_labels_str:
|
||||
return json.loads(self.tool_labels_str)
|
||||
return cast(dict, json.loads(self.tool_labels_str))
|
||||
else:
|
||||
return {}
|
||||
except Exception as e:
|
||||
@@ -1564,7 +1567,7 @@ class MessageAgentThought(db.Model):
|
||||
def tool_meta(self) -> dict:
|
||||
try:
|
||||
if self.tool_meta_str:
|
||||
return json.loads(self.tool_meta_str)
|
||||
return cast(dict, json.loads(self.tool_meta_str))
|
||||
else:
|
||||
return {}
|
||||
except Exception as e:
|
||||
@@ -1612,9 +1615,11 @@ class MessageAgentThought(db.Model):
|
||||
except Exception as e:
|
||||
if self.observation:
|
||||
return dict.fromkeys(tools, self.observation)
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
class DatasetRetrieverResource(db.Model):
|
||||
class DatasetRetrieverResource(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "dataset_retriever_resources"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"),
|
||||
@@ -1641,7 +1646,7 @@ class DatasetRetrieverResource(db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
|
||||
class Tag(db.Model):
|
||||
class Tag(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "tags"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tag_pkey"),
|
||||
@@ -1659,7 +1664,7 @@ class Tag(db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TagBinding(db.Model):
|
||||
class TagBinding(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "tag_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tag_binding_pkey"),
|
||||
@@ -1675,7 +1680,7 @@ class TagBinding(db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TraceAppConfig(db.Model):
|
||||
class TraceAppConfig(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "trace_app_config"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"),
|
||||
|
@@ -36,7 +36,7 @@ class ProviderQuotaType(Enum):
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class Provider(db.Model):
|
||||
class Provider(db.Model): # type: ignore[name-defined]
|
||||
"""
|
||||
Provider model representing the API providers and their configurations.
|
||||
"""
|
||||
@@ -89,7 +89,7 @@ class Provider(db.Model):
|
||||
return self.is_valid and self.token_is_set
|
||||
|
||||
|
||||
class ProviderModel(db.Model):
|
||||
class ProviderModel(db.Model): # type: ignore[name-defined]
|
||||
"""
|
||||
Provider model representing the API provider_models and their configurations.
|
||||
"""
|
||||
@@ -114,7 +114,7 @@ class ProviderModel(db.Model):
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TenantDefaultModel(db.Model):
|
||||
class TenantDefaultModel(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "tenant_default_models"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"),
|
||||
@@ -130,7 +130,7 @@ class TenantDefaultModel(db.Model):
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TenantPreferredModelProvider(db.Model):
|
||||
class TenantPreferredModelProvider(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "tenant_preferred_model_providers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"),
|
||||
@@ -145,7 +145,7 @@ class TenantPreferredModelProvider(db.Model):
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ProviderOrder(db.Model):
|
||||
class ProviderOrder(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "provider_orders"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="provider_order_pkey"),
|
||||
@@ -170,7 +170,7 @@ class ProviderOrder(db.Model):
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ProviderModelSetting(db.Model):
|
||||
class ProviderModelSetting(db.Model): # type: ignore[name-defined]
|
||||
"""
|
||||
Provider model settings for record the model enabled status and load balancing status.
|
||||
"""
|
||||
@@ -192,7 +192,7 @@ class ProviderModelSetting(db.Model):
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class LoadBalancingModelConfig(db.Model):
|
||||
class LoadBalancingModelConfig(db.Model): # type: ignore[name-defined]
|
||||
"""
|
||||
Configurations for load balancing models.
|
||||
"""
|
||||
|
@@ -7,7 +7,7 @@ from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class DataSourceOauthBinding(db.Model):
|
||||
class DataSourceOauthBinding(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "data_source_oauth_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="source_binding_pkey"),
|
||||
@@ -25,7 +25,7 @@ class DataSourceOauthBinding(db.Model):
|
||||
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
|
||||
|
||||
|
||||
class DataSourceApiKeyAuthBinding(db.Model):
|
||||
class DataSourceApiKeyAuthBinding(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "data_source_api_key_auth_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"),
|
||||
|
@@ -1,11 +1,11 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from celery import states
|
||||
from celery import states # type: ignore
|
||||
|
||||
from .engine import db
|
||||
|
||||
|
||||
class CeleryTask(db.Model):
|
||||
class CeleryTask(db.Model): # type: ignore[name-defined]
|
||||
"""Task result/status."""
|
||||
|
||||
__tablename__ = "celery_taskmeta"
|
||||
@@ -29,7 +29,7 @@ class CeleryTask(db.Model):
|
||||
queue = db.Column(db.String(155), nullable=True)
|
||||
|
||||
|
||||
class CeleryTaskSet(db.Model):
|
||||
class CeleryTaskSet(db.Model): # type: ignore[name-defined]
|
||||
"""TaskSet result."""
|
||||
|
||||
__tablename__ = "celery_tasksetmeta"
|
||||
|
@@ -14,7 +14,7 @@ from .model import Account, App, Tenant
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class BuiltinToolProvider(db.Model):
|
||||
class BuiltinToolProvider(db.Model): # type: ignore[name-defined]
|
||||
"""
|
||||
This table stores the tool provider information for built-in tools for each tenant.
|
||||
"""
|
||||
@@ -41,10 +41,10 @@ class BuiltinToolProvider(db.Model):
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict:
|
||||
return json.loads(self.encrypted_credentials)
|
||||
return dict(json.loads(self.encrypted_credentials))
|
||||
|
||||
|
||||
class PublishedAppTool(db.Model):
|
||||
class PublishedAppTool(db.Model): # type: ignore[name-defined]
|
||||
"""
|
||||
The table stores the apps published as a tool for each person.
|
||||
"""
|
||||
@@ -86,7 +86,7 @@ class PublishedAppTool(db.Model):
|
||||
return db.session.query(App).filter(App.id == self.app_id).first()
|
||||
|
||||
|
||||
class ApiToolProvider(db.Model):
|
||||
class ApiToolProvider(db.Model): # type: ignore[name-defined]
|
||||
"""
|
||||
The table stores the api providers.
|
||||
"""
|
||||
@@ -133,7 +133,7 @@ class ApiToolProvider(db.Model):
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict:
|
||||
return json.loads(self.credentials_str)
|
||||
return dict(json.loads(self.credentials_str))
|
||||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
@@ -144,7 +144,7 @@ class ApiToolProvider(db.Model):
|
||||
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
|
||||
|
||||
|
||||
class ToolLabelBinding(db.Model):
|
||||
class ToolLabelBinding(db.Model): # type: ignore[name-defined]
|
||||
"""
|
||||
The table stores the labels for tools.
|
||||
"""
|
||||
@@ -164,7 +164,7 @@ class ToolLabelBinding(db.Model):
|
||||
label_name = db.Column(db.String(40), nullable=False)
|
||||
|
||||
|
||||
class WorkflowToolProvider(db.Model):
|
||||
class WorkflowToolProvider(db.Model): # type: ignore[name-defined]
|
||||
"""
|
||||
The table stores the workflow providers.
|
||||
"""
|
||||
@@ -218,7 +218,7 @@ class WorkflowToolProvider(db.Model):
|
||||
return db.session.query(App).filter(App.id == self.app_id).first()
|
||||
|
||||
|
||||
class ToolModelInvoke(db.Model):
|
||||
class ToolModelInvoke(db.Model): # type: ignore[name-defined]
|
||||
"""
|
||||
store the invoke logs from tool invoke
|
||||
"""
|
||||
@@ -255,7 +255,7 @@ class ToolModelInvoke(db.Model):
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ToolConversationVariables(db.Model):
|
||||
class ToolConversationVariables(db.Model): # type: ignore[name-defined]
|
||||
"""
|
||||
store the conversation variables from tool invoke
|
||||
"""
|
||||
@@ -283,10 +283,10 @@ class ToolConversationVariables(db.Model):
|
||||
|
||||
@property
|
||||
def variables(self) -> dict:
|
||||
return json.loads(self.variables_str)
|
||||
return dict(json.loads(self.variables_str))
|
||||
|
||||
|
||||
class ToolFile(db.Model):
|
||||
class ToolFile(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "tool_files"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_file_pkey"),
|
||||
|
@@ -6,7 +6,7 @@ from .model import Message
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class SavedMessage(db.Model):
|
||||
class SavedMessage(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "saved_messages"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="saved_message_pkey"),
|
||||
@@ -25,7 +25,7 @@ class SavedMessage(db.Model):
|
||||
return db.session.query(Message).filter(Message.id == self.message_id).first()
|
||||
|
||||
|
||||
class PinnedConversation(db.Model):
|
||||
class PinnedConversation(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "pinned_conversations"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"),
|
||||
|
@@ -2,7 +2,7 @@ import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import func
|
||||
@@ -20,6 +20,9 @@ from .account import Account
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.model import AppMode, Message
|
||||
|
||||
|
||||
class WorkflowType(Enum):
|
||||
"""
|
||||
@@ -56,7 +59,7 @@ class WorkflowType(Enum):
|
||||
return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT
|
||||
|
||||
|
||||
class Workflow(db.Model):
|
||||
class Workflow(db.Model): # type: ignore[name-defined]
|
||||
"""
|
||||
Workflow, for `Workflow App` and `Chat App workflow mode`.
|
||||
|
||||
@@ -182,7 +185,7 @@ class Workflow(db.Model):
|
||||
self._features = value
|
||||
|
||||
@property
|
||||
def features_dict(self) -> Mapping[str, Any]:
|
||||
def features_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.features) if self.features else {}
|
||||
|
||||
def user_input_form(self, to_old_structure: bool = False) -> list:
|
||||
@@ -199,7 +202,7 @@ class Workflow(db.Model):
|
||||
return []
|
||||
|
||||
# get user_input_form from start node
|
||||
variables = start_node.get("data", {}).get("variables", [])
|
||||
variables: list[Any] = start_node.get("data", {}).get("variables", [])
|
||||
|
||||
if to_old_structure:
|
||||
old_structure_variables = []
|
||||
@@ -344,7 +347,7 @@ class WorkflowRunStatus(StrEnum):
|
||||
raise ValueError(f"invalid workflow run status value {value}")
|
||||
|
||||
|
||||
class WorkflowRun(db.Model):
|
||||
class WorkflowRun(db.Model): # type: ignore[name-defined]
|
||||
"""
|
||||
Workflow Run
|
||||
|
||||
@@ -546,7 +549,7 @@ class WorkflowNodeExecutionStatus(Enum):
|
||||
raise ValueError(f"invalid workflow node execution status value {value}")
|
||||
|
||||
|
||||
class WorkflowNodeExecution(db.Model):
|
||||
class WorkflowNodeExecution(db.Model): # type: ignore[name-defined]
|
||||
"""
|
||||
Workflow Node Execution
|
||||
|
||||
@@ -712,7 +715,7 @@ class WorkflowAppLogCreatedFrom(Enum):
|
||||
raise ValueError(f"invalid workflow app log created from value {value}")
|
||||
|
||||
|
||||
class WorkflowAppLog(db.Model):
|
||||
class WorkflowAppLog(db.Model): # type: ignore[name-defined]
|
||||
"""
|
||||
Workflow App execution log, excluding workflow debugging records.
|
||||
|
||||
@@ -774,7 +777,7 @@ class WorkflowAppLog(db.Model):
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
|
||||
|
||||
|
||||
class ConversationVariable(db.Model):
|
||||
class ConversationVariable(db.Model): # type: ignore[name-defined]
|
||||
__tablename__ = "workflow_conversation_variables"
|
||||
|
||||
id: Mapped[str] = db.Column(StringUUID, primary_key=True)
|
||||
|
Reference in New Issue
Block a user