From c0702aacacfa12c115237aaafce1b2efc1408e8b Mon Sep 17 00:00:00 2001 From: Zhehao Peng <32246435+Zhehao-P@users.noreply.github.com> Date: Mon, 18 Aug 2025 06:34:13 -0700 Subject: [PATCH] Use typing.Literal to replace str places (#24099) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/annotation.py | 6 ++--- .../console/datasets/datasets_document.py | 8 +++---- api/controllers/console/datasets/metadata.py | 4 +++- api/controllers/service_api/app/annotation.py | 6 ++--- .../service_api/dataset/dataset.py | 6 +++-- .../service_api/dataset/metadata.py | 4 +++- api/services/dataset_service.py | 23 +++++++++++-------- api/tasks/deal_dataset_vector_index_task.py | 3 ++- 8 files changed, 34 insertions(+), 26 deletions(-) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 493a9a52e..2caa908d4 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,3 +1,5 @@ +from typing import Literal + from flask import request from flask_login import current_user from flask_restful import Resource, marshal, marshal_with, reqparse @@ -24,7 +26,7 @@ class AnnotationReplyActionApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") - def post(self, app_id, action): + def post(self, app_id, action: Literal["enable", "disable"]): if not current_user.is_editor: raise Forbidden() @@ -38,8 +40,6 @@ class AnnotationReplyActionApi(Resource): result = AppAnnotationService.enable_app_annotation(args, app_id) elif action == "disable": result = AppAnnotationService.disable_app_annotation(app_id) - else: - raise ValueError("Unsupported annotation reply action") return result, 200 diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 4e0955bd4..413b018ba 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,6 +1,6 @@ import logging from argparse import ArgumentTypeError -from typing import cast +from typing import Literal, cast from flask import request from flask_login import current_user @@ -758,7 +758,7 @@ class DocumentProcessingApi(DocumentResource): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def patch(self, dataset_id, document_id, action): + def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]): dataset_id = str(dataset_id) document_id = str(document_id) document = self.get_document(dataset_id, document_id) @@ -784,8 +784,6 @@ class DocumentProcessingApi(DocumentResource): document.paused_at = None document.is_paused = False db.session.commit() - else: - raise InvalidActionError() return {"result": "success"}, 200 @@ -840,7 +838,7 @@ class DocumentStatusApi(DocumentResource): @account_initialization_required @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") - def patch(self, dataset_id, action): + def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if dataset is None: diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 65f76fb40..1b5570285 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -1,3 +1,5 @@ +from typing import Literal + from flask_login import current_user from flask_restful import Resource, marshal_with, reqparse from werkzeug.exceptions import NotFound @@ -100,7 +102,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): @login_required @account_initialization_required @enterprise_license_required - def post(self, dataset_id, action): + def post(self, dataset_id, action: Literal["enable", "disable"]): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 9b22c535f..23446bb70 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -1,3 +1,5 @@ +from typing import Literal + from flask import request from flask_restful import Resource, marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden @@ -15,7 +17,7 @@ from services.annotation_service import AppAnnotationService class AnnotationReplyActionApi(Resource): @validate_app_token - def post(self, app_model: App, action): + def post(self, app_model: App, action: Literal["enable", "disable"]): parser = reqparse.RequestParser() parser.add_argument("score_threshold", required=True, type=float, location="json") parser.add_argument("embedding_provider_name", required=True, type=str, location="json") @@ -25,8 +27,6 @@ class AnnotationReplyActionApi(Resource): result = AppAnnotationService.enable_app_annotation(args, app_model.id) elif action == "disable": result = AppAnnotationService.disable_app_annotation(app_model.id) - else: - raise ValueError("Unsupported annotation reply action") return result, 200 diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 29eef4125..35b1efeff 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,3 +1,5 @@ +from typing import Literal + from flask import request from flask_restful import marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound @@ -358,14 +360,14 @@ class DatasetApi(DatasetApiResource): class DocumentStatusApi(DatasetApiResource): """Resource for batch document status operations.""" - def patch(self, tenant_id, dataset_id, action): + def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]): """ Batch update document status. Args: tenant_id: tenant id dataset_id: dataset id - action: action to perform (enable, disable, archive, un_archive) + action: action to perform (Literal["enable", "disable", "archive", "un_archive"]) Returns: dict: A dictionary with a key 'result' and a value 'success' diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 6ba818c5f..75a0b1828 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -1,3 +1,5 @@ +from typing import Literal + from flask_login import current_user # type: ignore from flask_restful import marshal, reqparse from werkzeug.exceptions import NotFound @@ -77,7 +79,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource): class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id, action): + def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 8934608da..6ddda4c0c 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -6,7 +6,7 @@ import secrets import time import uuid from collections import Counter -from typing import Any, Optional +from typing import Any, Literal, Optional from flask_login import current_user from sqlalchemy import func, select @@ -51,7 +51,7 @@ from services.entities.knowledge_entities.knowledge_entities import ( RetrievalModel, SegmentUpdateArgs, ) -from services.errors.account import InvalidActionError, NoPermissionError +from services.errors.account import NoPermissionError from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.dataset import DatasetNameDuplicateError from services.errors.document import DocumentIndexingError @@ -1800,14 +1800,16 @@ class DocumentService: raise ValueError("Process rule segmentation max_tokens is invalid") @staticmethod - def batch_update_document_status(dataset: Dataset, document_ids: list[str], action: str, user): + def batch_update_document_status( + dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user + ): """ Batch update document status. Args: dataset (Dataset): The dataset object document_ids (list[str]): List of document IDs to update - action (str): Action to perform (enable, disable, archive, un_archive) + action (Literal["enable", "disable", "archive", "un_archive"]): Action to perform user: Current user performing the action Raises: @@ -1890,9 +1892,10 @@ class DocumentService: raise propagation_error @staticmethod - def _prepare_document_status_update(document, action: str, user): - """ - Prepare document status update information. + def _prepare_document_status_update( + document: Document, action: Literal["enable", "disable", "archive", "un_archive"], user + ): + """Prepare document status update information. Args: document: Document object to update @@ -2355,7 +2358,9 @@ class SegmentService: db.session.commit() @classmethod - def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document): + def update_segments_status( + cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document + ): # Check if segment_ids is not empty to avoid WHERE false condition if not segment_ids or len(segment_ids) == 0: return @@ -2413,8 +2418,6 @@ class SegmentService: db.session.commit() disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) - else: - raise InvalidActionError() @classmethod def create_child_chunk( diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index 8c4c1876a..5ab377c23 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -1,5 +1,6 @@ import logging import time +from typing import Literal import click from celery import shared_task # type: ignore @@ -13,7 +14,7 @@ from models.dataset import Document as DatasetDocument @shared_task(queue="dataset") -def deal_dataset_vector_index_task(dataset_id: str, action: str): +def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]): """ Async deal dataset from index :param dataset_id: dataset_id