feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -14,7 +14,7 @@ from models.dataset import Document
@document_index_created.connect
def handle(sender, **kwargs):
dataset_id = sender
document_ids = kwargs.get("document_ids")
document_ids = kwargs.get("document_ids", [])
documents = []
start_at = time.perf_counter()
for document_id in document_ids:

View File

@@ -8,18 +8,19 @@ def handle(sender, **kwargs):
"""Create site record when an app is created."""
app = sender
account = kwargs.get("account")
site = Site(
app_id=app.id,
title=app.name,
icon_type=app.icon_type,
icon=app.icon,
icon_background=app.icon_background,
default_language=account.interface_language,
customize_token_strategy="not_allow",
code=Site.generate_code(16),
created_by=app.created_by,
updated_by=app.updated_by,
)
if account is not None:
site = Site(
app_id=app.id,
title=app.name,
icon_type=app.icon_type,
icon=app.icon,
icon_background=app.icon_background,
default_language=account.interface_language,
customize_token_strategy="not_allow",
code=Site.generate_code(16),
created_by=app.created_by,
updated_by=app.updated_by,
)
db.session.add(site)
db.session.commit()
db.session.add(site)
db.session.commit()

View File

@@ -44,7 +44,7 @@ def handle(sender, **kwargs):
else:
used_quota = 1
if used_quota is not None:
if used_quota is not None and system_configuration.current_quota_type is not None:
db.session.query(Provider).filter(
Provider.tenant_id == application_generate_entity.app_config.tenant_id,
Provider.provider_name == model_config.provider,

View File

@@ -8,7 +8,10 @@ from events.app_event import app_draft_workflow_was_synced
@app_draft_workflow_was_synced.connect
def handle(sender, **kwargs):
app = sender
for node_data in kwargs.get("synced_draft_workflow").graph_dict.get("nodes", []):
synced_draft_workflow = kwargs.get("synced_draft_workflow")
if synced_draft_workflow is None:
return
for node_data in synced_draft_workflow.graph_dict.get("nodes", []):
if node_data.get("data", {}).get("type") == NodeType.TOOL.value:
try:
tool_entity = ToolEntity(**node_data["data"])

View File

@@ -8,16 +8,18 @@ from models.model import AppModelConfig
def handle(sender, **kwargs):
app = sender
app_model_config = kwargs.get("app_model_config")
if app_model_config is None:
return
dataset_ids = get_dataset_ids_from_model_config(app_model_config)
app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
removed_dataset_ids = []
removed_dataset_ids: set[int] = set()
if not app_dataset_joins:
added_dataset_ids = dataset_ids
else:
old_dataset_ids = set()
old_dataset_ids: set[int] = set()
old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)
added_dataset_ids = dataset_ids - old_dataset_ids
@@ -37,8 +39,8 @@ def handle(sender, **kwargs):
db.session.commit()
def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set:
dataset_ids = set()
def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[int]:
dataset_ids: set[int] = set()
if not app_model_config:
return dataset_ids

View File

@@ -17,11 +17,11 @@ def handle(sender, **kwargs):
dataset_ids = get_dataset_ids_from_workflow(published_workflow)
app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
removed_dataset_ids = []
removed_dataset_ids: set[int] = set()
if not app_dataset_joins:
added_dataset_ids = dataset_ids
else:
old_dataset_ids = set()
old_dataset_ids: set[int] = set()
old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)
added_dataset_ids = dataset_ids - old_dataset_ids
@@ -41,8 +41,8 @@ def handle(sender, **kwargs):
db.session.commit()
def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set:
dataset_ids = set()
def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[int]:
dataset_ids: set[int] = set()
graph = published_workflow.graph_dict
if not graph:
return dataset_ids
@@ -60,7 +60,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set:
for node in knowledge_retrieval_nodes:
try:
node_data = KnowledgeRetrievalNodeData(**node.get("data", {}))
dataset_ids.update(node_data.dataset_ids)
dataset_ids.update(int(dataset_id) for dataset_id in node_data.dataset_ids)
except Exception as e:
continue