Use typing.Literal to replace str places (#24099)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Zhehao Peng
2025-08-18 06:34:13 -07:00
committed by GitHub
parent 670d479e32
commit c0702aacac
8 changed files with 34 additions and 26 deletions

View File

@@ -1,3 +1,5 @@
from typing import Literal
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, marshal, marshal_with, reqparse from flask_restful import Resource, marshal, marshal_with, reqparse
@@ -24,7 +26,7 @@ class AnnotationReplyActionApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
def post(self, app_id, action): def post(self, app_id, action: Literal["enable", "disable"]):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
@@ -38,8 +40,6 @@ class AnnotationReplyActionApi(Resource):
result = AppAnnotationService.enable_app_annotation(args, app_id) result = AppAnnotationService.enable_app_annotation(args, app_id)
elif action == "disable": elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id) result = AppAnnotationService.disable_app_annotation(app_id)
else:
raise ValueError("Unsupported annotation reply action")
return result, 200 return result, 200

View File

@@ -1,6 +1,6 @@
import logging import logging
from argparse import ArgumentTypeError from argparse import ArgumentTypeError
from typing import cast from typing import Literal, cast
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
@@ -758,7 +758,7 @@ class DocumentProcessingApi(DocumentResource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action): def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id) document_id = str(document_id)
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
@@ -784,8 +784,6 @@ class DocumentProcessingApi(DocumentResource):
document.paused_at = None document.paused_at = None
document.is_paused = False document.is_paused = False
db.session.commit() db.session.commit()
else:
raise InvalidActionError()
return {"result": "success"}, 200 return {"result": "success"}, 200
@@ -840,7 +838,7 @@ class DocumentStatusApi(DocumentResource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, action): def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if dataset is None: if dataset is None:

View File

@@ -1,3 +1,5 @@
from typing import Literal
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@@ -100,7 +102,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def post(self, dataset_id, action): def post(self, dataset_id, action: Literal["enable", "disable"]):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:

View File

@@ -1,3 +1,5 @@
from typing import Literal
from flask import request from flask import request
from flask_restful import Resource, marshal, marshal_with, reqparse from flask_restful import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@@ -15,7 +17,7 @@ from services.annotation_service import AppAnnotationService
class AnnotationReplyActionApi(Resource): class AnnotationReplyActionApi(Resource):
@validate_app_token @validate_app_token
def post(self, app_model: App, action): def post(self, app_model: App, action: Literal["enable", "disable"]):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json") parser.add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_provider_name", required=True, type=str, location="json") parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
@@ -25,8 +27,6 @@ class AnnotationReplyActionApi(Resource):
result = AppAnnotationService.enable_app_annotation(args, app_model.id) result = AppAnnotationService.enable_app_annotation(args, app_model.id)
elif action == "disable": elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id) result = AppAnnotationService.disable_app_annotation(app_model.id)
else:
raise ValueError("Unsupported annotation reply action")
return result, 200 return result, 200

View File

@@ -1,3 +1,5 @@
from typing import Literal
from flask import request from flask import request
from flask_restful import marshal, marshal_with, reqparse from flask_restful import marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@@ -358,14 +360,14 @@ class DatasetApi(DatasetApiResource):
class DocumentStatusApi(DatasetApiResource): class DocumentStatusApi(DatasetApiResource):
"""Resource for batch document status operations.""" """Resource for batch document status operations."""
def patch(self, tenant_id, dataset_id, action): def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
""" """
Batch update document status. Batch update document status.
Args: Args:
tenant_id: tenant id tenant_id: tenant id
dataset_id: dataset id dataset_id: dataset id
action: action to perform (enable, disable, archive, un_archive) action: action to perform (Literal["enable", "disable", "archive", "un_archive"])
Returns: Returns:
dict: A dictionary with a key 'result' and a value 'success' dict: A dictionary with a key 'result' and a value 'success'

View File

@@ -1,3 +1,5 @@
from typing import Literal
from flask_login import current_user # type: ignore from flask_login import current_user # type: ignore
from flask_restful import marshal, reqparse from flask_restful import marshal, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@@ -77,7 +79,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, action): def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:

View File

@@ -6,7 +6,7 @@ import secrets
import time import time
import uuid import uuid
from collections import Counter from collections import Counter
from typing import Any, Optional from typing import Any, Literal, Optional
from flask_login import current_user from flask_login import current_user
from sqlalchemy import func, select from sqlalchemy import func, select
@@ -51,7 +51,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
RetrievalModel, RetrievalModel,
SegmentUpdateArgs, SegmentUpdateArgs,
) )
from services.errors.account import InvalidActionError, NoPermissionError from services.errors.account import NoPermissionError
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
from services.errors.dataset import DatasetNameDuplicateError from services.errors.dataset import DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError from services.errors.document import DocumentIndexingError
@@ -1800,14 +1800,16 @@ class DocumentService:
raise ValueError("Process rule segmentation max_tokens is invalid") raise ValueError("Process rule segmentation max_tokens is invalid")
@staticmethod @staticmethod
def batch_update_document_status(dataset: Dataset, document_ids: list[str], action: str, user): def batch_update_document_status(
dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user
):
""" """
Batch update document status. Batch update document status.
Args: Args:
dataset (Dataset): The dataset object dataset (Dataset): The dataset object
document_ids (list[str]): List of document IDs to update document_ids (list[str]): List of document IDs to update
action (str): Action to perform (enable, disable, archive, un_archive) action (Literal["enable", "disable", "archive", "un_archive"]): Action to perform
user: Current user performing the action user: Current user performing the action
Raises: Raises:
@@ -1890,9 +1892,10 @@ class DocumentService:
raise propagation_error raise propagation_error
@staticmethod @staticmethod
def _prepare_document_status_update(document, action: str, user): def _prepare_document_status_update(
""" document: Document, action: Literal["enable", "disable", "archive", "un_archive"], user
Prepare document status update information. ):
"""Prepare document status update information.
Args: Args:
document: Document object to update document: Document object to update
@@ -2355,7 +2358,9 @@ class SegmentService:
db.session.commit() db.session.commit()
@classmethod @classmethod
def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document): def update_segments_status(
cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document
):
# Check if segment_ids is not empty to avoid WHERE false condition # Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0: if not segment_ids or len(segment_ids) == 0:
return return
@@ -2413,8 +2418,6 @@ class SegmentService:
db.session.commit() db.session.commit()
disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
else:
raise InvalidActionError()
@classmethod @classmethod
def create_child_chunk( def create_child_chunk(

View File

@@ -1,5 +1,6 @@
import logging import logging
import time import time
from typing import Literal
import click import click
from celery import shared_task # type: ignore from celery import shared_task # type: ignore
@@ -13,7 +14,7 @@ from models.dataset import Document as DatasetDocument
@shared_task(queue="dataset") @shared_task(queue="dataset")
def deal_dataset_vector_index_task(dataset_id: str, action: str): def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]):
""" """
Async deal dataset from index Async deal dataset from index
:param dataset_id: dataset_id :param dataset_id: dataset_id