diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index 6962790ca..df9de825d 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -1,4 +1,4 @@ -from flask_restx import fields +from flask_restx import Api, Namespace, fields from libs.helper import AppIconUrlField @@ -10,6 +10,12 @@ parameters__system_parameters = { "workflow_file_upload_limit": fields.Integer, } + +def build_system_parameters_model(api_or_ns: Api | Namespace): + """Build the system parameters model for the API or Namespace.""" + return api_or_ns.model("SystemParameters", parameters__system_parameters) + + parameters_fields = { "opening_statement": fields.String, "suggested_questions": fields.Raw, @@ -25,6 +31,14 @@ parameters_fields = { "system_parameters": fields.Nested(parameters__system_parameters), } + +def build_parameters_model(api_or_ns: Api | Namespace): + """Build the parameters model for the API or Namespace.""" + copied_fields = parameters_fields.copy() + copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns)) + return api_or_ns.model("Parameters", copied_fields) + + site_fields = { "title": fields.String, "chat_color_theme": fields.String, @@ -41,3 +55,8 @@ site_fields = { "show_workflow_steps": fields.Boolean, "use_icon_as_answer_icon": fields.Boolean, } + + +def build_site_model(api_or_ns: Api | Namespace): + """Build the site model for the API or Namespace.""" + return api_or_ns.model("Site", site_fields) diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index a26b13ca4..c45e7dbb2 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -5,7 +5,7 @@ from werkzeug.exceptions import Forbidden from controllers.console import api from controllers.console.wraps import account_initialization_required, setup_required -from fields.tag_fields import tag_fields +from fields.tag_fields import dataset_tag_fields from libs.login import login_required from models.model import Tag from services.tag_service import TagService @@ -21,7 +21,7 @@ class TagListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(tag_fields) + @marshal_with(dataset_tag_fields) def get(self): tag_type = request.args.get("type", type=str, default="") keyword = request.args.get("keyword", default=None, type=str) diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index b26f29d98..aaa3c8f9a 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -1,11 +1,23 @@ from flask import Blueprint +from flask_restx import Namespace from libs.external_api import ExternalApi bp = Blueprint("service_api", __name__, url_prefix="/v1") -api = ExternalApi(bp) + +api = ExternalApi( + bp, + version="1.0", + title="Service API", + description="API for application services", + doc="/docs", # Enable Swagger UI at /v1/docs +) + +service_api_ns = Namespace("service_api", description="Service operations") from . import index from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow from .dataset import dataset, document, hit_testing, metadata, segment, upload_file from .workspace import models + +api.add_namespace(service_api_ns) diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 5fde35e98..6bc94af8c 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -1,28 +1,51 @@ from typing import Literal from flask import request -from flask_restx import Resource, marshal, marshal_with, reqparse +from flask_restx import Api, Namespace, Resource, fields, reqparse +from flask_restx.api import HTTPStatus from werkzeug.exceptions import Forbidden -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client -from fields.annotation_fields import ( - annotation_fields, -) +from fields.annotation_fields import annotation_fields, build_annotation_model from libs.login import current_user from models.model import App from services.annotation_service import AppAnnotationService +# Define parsers for annotation API +annotation_create_parser = reqparse.RequestParser() +annotation_create_parser.add_argument("question", required=True, type=str, location="json", help="Annotation question") +annotation_create_parser.add_argument("answer", required=True, type=str, location="json", help="Annotation answer") +annotation_reply_action_parser = reqparse.RequestParser() +annotation_reply_action_parser.add_argument( + "score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching" +) +annotation_reply_action_parser.add_argument( + "embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name" +) +annotation_reply_action_parser.add_argument( + "embedding_model_name", required=True, type=str, location="json", help="Embedding model name" +) + + +@service_api_ns.route("/apps/annotation-reply/") class AnnotationReplyActionApi(Resource): + @service_api_ns.expect(annotation_reply_action_parser) + @service_api_ns.doc("annotation_reply_action") + @service_api_ns.doc(description="Enable or disable annotation reply feature") + @service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"}) + @service_api_ns.doc( + responses={ + 200: "Action completed successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_app_token def post(self, app_model: App, action: Literal["enable", "disable"]): - parser = reqparse.RequestParser() - parser.add_argument("score_threshold", required=True, type=float, location="json") - parser.add_argument("embedding_provider_name", required=True, type=str, location="json") - parser.add_argument("embedding_model_name", required=True, type=str, location="json") - args = parser.parse_args() + """Enable or disable annotation reply feature.""" + args = annotation_reply_action_parser.parse_args() if action == "enable": result = AppAnnotationService.enable_app_annotation(args, app_model.id) elif action == "disable": @@ -30,9 +53,21 @@ class AnnotationReplyActionApi(Resource): return result, 200 +@service_api_ns.route("/apps/annotation-reply//status/") class AnnotationReplyActionStatusApi(Resource): + @service_api_ns.doc("get_annotation_reply_action_status") + @service_api_ns.doc(description="Get the status of an annotation reply action job") + @service_api_ns.doc(params={"action": "Action type", "job_id": "Job ID"}) + @service_api_ns.doc( + responses={ + 200: "Job status retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Job not found", + } + ) @validate_app_token def get(self, app_model: App, job_id, action): + """Get the status of an annotation reply action job.""" job_id = str(job_id) app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}" cache_result = redis_client.get(app_annotation_job_key) @@ -48,60 +83,111 @@ class AnnotationReplyActionStatusApi(Resource): return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 +# Define annotation list response model +annotation_list_fields = { + "data": fields.List(fields.Nested(annotation_fields)), + "has_more": fields.Boolean, + "limit": fields.Integer, + "total": fields.Integer, + "page": fields.Integer, +} + + +def build_annotation_list_model(api_or_ns: Api | Namespace): + """Build the annotation list model for the API or Namespace.""" + copied_annotation_list_fields = annotation_list_fields.copy() + copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns))) + return api_or_ns.model("AnnotationList", copied_annotation_list_fields) + + +@service_api_ns.route("/apps/annotations") class AnnotationListApi(Resource): + @service_api_ns.doc("list_annotations") + @service_api_ns.doc(description="List annotations for the application") + @service_api_ns.doc( + responses={ + 200: "Annotations retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_app_token + @service_api_ns.marshal_with(build_annotation_list_model(service_api_ns)) def get(self, app_model: App): + """List annotations for the application.""" page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) keyword = request.args.get("keyword", default="", type=str) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword) - response = { - "data": marshal(annotation_list, annotation_fields), + return { + "data": annotation_list, "has_more": len(annotation_list) == limit, "limit": limit, "total": total, "page": page, } - return response, 200 + @service_api_ns.expect(annotation_create_parser) + @service_api_ns.doc("create_annotation") + @service_api_ns.doc(description="Create a new annotation") + @service_api_ns.doc( + responses={ + 201: "Annotation created successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_app_token - @marshal_with(annotation_fields) + @service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED) def post(self, app_model: App): - parser = reqparse.RequestParser() - parser.add_argument("question", required=True, type=str, location="json") - parser.add_argument("answer", required=True, type=str, location="json") - args = parser.parse_args() + """Create a new annotation.""" + args = annotation_create_parser.parse_args() annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id) - return annotation + return annotation, 201 +@service_api_ns.route("/apps/annotations/") class AnnotationUpdateDeleteApi(Resource): + @service_api_ns.expect(annotation_create_parser) + @service_api_ns.doc("update_annotation") + @service_api_ns.doc(description="Update an existing annotation") + @service_api_ns.doc(params={"annotation_id": "Annotation ID"}) + @service_api_ns.doc( + responses={ + 200: "Annotation updated successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Annotation not found", + } + ) @validate_app_token - @marshal_with(annotation_fields) + @service_api_ns.marshal_with(build_annotation_model(service_api_ns)) def put(self, app_model: App, annotation_id): + """Update an existing annotation.""" if not current_user.is_editor: raise Forbidden() annotation_id = str(annotation_id) - parser = reqparse.RequestParser() - parser.add_argument("question", required=True, type=str, location="json") - parser.add_argument("answer", required=True, type=str, location="json") - args = parser.parse_args() + args = annotation_create_parser.parse_args() annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) return annotation + @service_api_ns.doc("delete_annotation") + @service_api_ns.doc(description="Delete an annotation") + @service_api_ns.doc(params={"annotation_id": "Annotation ID"}) + @service_api_ns.doc( + responses={ + 204: "Annotation deleted successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Annotation not found", + } + ) @validate_app_token def delete(self, app_model: App, annotation_id): + """Delete an annotation.""" if not current_user.is_editor: raise Forbidden() annotation_id = str(annotation_id) AppAnnotationService.delete_app_annotation(app_model.id, annotation_id) return {"result": "success"}, 204 - - -api.add_resource(AnnotationReplyActionApi, "/apps/annotation-reply/") -api.add_resource(AnnotationReplyActionStatusApi, "/apps/annotation-reply//status/") -api.add_resource(AnnotationListApi, "/apps/annotations") -api.add_resource(AnnotationUpdateDeleteApi, "/apps/annotations/") diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index b881942dd..2dbeed1d6 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,7 +1,7 @@ -from flask_restx import Resource, marshal_with +from flask_restx import Resource -from controllers.common import fields -from controllers.service_api import api +from controllers.common.fields import build_parameters_model +from controllers.service_api import service_api_ns from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.wraps import validate_app_token from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict @@ -9,13 +9,26 @@ from models.model import App, AppMode from services.app_service import AppService +@service_api_ns.route("/parameters") class AppParameterApi(Resource): """Resource for app variables.""" + @service_api_ns.doc("get_app_parameters") + @service_api_ns.doc(description="Retrieve application input parameters and configuration") + @service_api_ns.doc( + responses={ + 200: "Parameters retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Application not found", + } + ) @validate_app_token - @marshal_with(fields.parameters_fields) + @service_api_ns.marshal_with(build_parameters_model(service_api_ns)) def get(self, app_model: App): - """Retrieve app parameters.""" + """Retrieve app parameters. + + Returns the input form parameters and configuration for the application. + """ if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: @@ -35,17 +48,43 @@ class AppParameterApi(Resource): return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) +@service_api_ns.route("/meta") class AppMetaApi(Resource): + @service_api_ns.doc("get_app_meta") + @service_api_ns.doc(description="Get application metadata") + @service_api_ns.doc( + responses={ + 200: "Metadata retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Application not found", + } + ) @validate_app_token def get(self, app_model: App): - """Get app meta""" + """Get app metadata. + + Returns metadata about the application including configuration and settings. + """ return AppService().get_app_meta(app_model) +@service_api_ns.route("/info") class AppInfoApi(Resource): + @service_api_ns.doc("get_app_info") + @service_api_ns.doc(description="Get basic application information") + @service_api_ns.doc( + responses={ + 200: "Application info retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Application not found", + } + ) @validate_app_token def get(self, app_model: App): - """Get app information""" + """Get app information. + + Returns basic information about the application including name, description, tags, and mode. + """ tags = [tag.name for tag in app_model.tags] return { "name": app_model.name, @@ -54,8 +93,3 @@ class AppInfoApi(Resource): "mode": app_model.mode, "author_name": app_model.author_name, } - - -api.add_resource(AppParameterApi, "/parameters") -api.add_resource(AppMetaApi, "/meta") -api.add_resource(AppInfoApi, "/info") diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 882bf59b5..61b3020a5 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -5,7 +5,7 @@ from flask_restx import Resource, reqparse from werkzeug.exceptions import InternalServerError import services -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -30,9 +30,26 @@ from services.errors.audio import ( ) +@service_api_ns.route("/audio-to-text") class AudioApi(Resource): + @service_api_ns.doc("audio_to_text") + @service_api_ns.doc(description="Convert audio to text using speech-to-text") + @service_api_ns.doc( + responses={ + 200: "Audio successfully transcribed", + 400: "Bad request - no audio or invalid audio", + 401: "Unauthorized - invalid API token", + 413: "Audio file too large", + 415: "Unsupported audio type", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) def post(self, app_model: App, end_user: EndUser): + """Convert audio to text using speech-to-text. + + Accepts an audio file upload and returns the transcribed text. + """ file = request.files["file"] try: @@ -65,16 +82,35 @@ class AudioApi(Resource): raise InternalServerError() +# Define parser for text-to-audio API +text_to_audio_parser = reqparse.RequestParser() +text_to_audio_parser.add_argument("message_id", type=str, required=False, location="json", help="Message ID") +text_to_audio_parser.add_argument("voice", type=str, location="json", help="Voice to use for TTS") +text_to_audio_parser.add_argument("text", type=str, location="json", help="Text to convert to audio") +text_to_audio_parser.add_argument("streaming", type=bool, location="json", help="Enable streaming response") + + +@service_api_ns.route("/text-to-audio") class TextApi(Resource): + @service_api_ns.expect(text_to_audio_parser) + @service_api_ns.doc("text_to_audio") + @service_api_ns.doc(description="Convert text to audio using text-to-speech") + @service_api_ns.doc( + responses={ + 200: "Text successfully converted to audio", + 400: "Bad request - invalid parameters", + 401: "Unauthorized - invalid API token", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) def post(self, app_model: App, end_user: EndUser): + """Convert text to audio using text-to-speech. + + Converts the provided text to audio using the specified voice. + """ try: - parser = reqparse.RequestParser() - parser.add_argument("message_id", type=str, required=False, location="json") - parser.add_argument("voice", type=str, location="json") - parser.add_argument("text", type=str, location="json") - parser.add_argument("streaming", type=bool, location="json") - args = parser.parse_args() + args = text_to_audio_parser.parse_args() message_id = args.get("message_id", None) text = args.get("text", None) @@ -108,7 +144,3 @@ class TextApi(Resource): except Exception as e: logging.exception("internal server error.") raise InternalServerError() - - -api.add_resource(AudioApi, "/audio-to-text") -api.add_resource(TextApi, "/text-to-audio") diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 5f37f4fd9..dddb75d59 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -5,7 +5,7 @@ from flask_restx import Resource, reqparse from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( AppUnavailableError, CompletionRequestError, @@ -33,21 +33,68 @@ from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError +# Define parser for completion API +completion_parser = reqparse.RequestParser() +completion_parser.add_argument( + "inputs", type=dict, required=True, location="json", help="Input parameters for completion" +) +completion_parser.add_argument("query", type=str, location="json", default="", help="The query string") +completion_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments") +completion_parser.add_argument( + "response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode" +) +completion_parser.add_argument( + "retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source" +) +# Define parser for chat API +chat_parser = reqparse.RequestParser() +chat_parser.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat") +chat_parser.add_argument("query", type=str, required=True, location="json", help="The chat query") +chat_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments") +chat_parser.add_argument( + "response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode" +) +chat_parser.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID") +chat_parser.add_argument( + "retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source" +) +chat_parser.add_argument( + "auto_generate_name", + type=bool, + required=False, + default=True, + location="json", + help="Auto generate conversation name", +) +chat_parser.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat") + + +@service_api_ns.route("/completion-messages") class CompletionApi(Resource): + @service_api_ns.expect(completion_parser) + @service_api_ns.doc("create_completion") + @service_api_ns.doc(description="Create a completion for the given prompt") + @service_api_ns.doc( + responses={ + 200: "Completion created successfully", + 400: "Bad request - invalid parameters", + 401: "Unauthorized - invalid API token", + 404: "Conversation not found", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): + """Create a completion for the given prompt. + + This endpoint generates a completion based on the provided inputs and query. + Supports both blocking and streaming response modes. + """ if app_model.mode != "completion": raise AppUnavailableError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, location="json", default="") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") - - args = parser.parse_args() + args = completion_parser.parse_args() external_trace_id = get_external_trace_id(request) if external_trace_id: args["external_trace_id"] = external_trace_id @@ -88,9 +135,21 @@ class CompletionApi(Resource): raise InternalServerError() +@service_api_ns.route("/completion-messages//stop") class CompletionStopApi(Resource): + @service_api_ns.doc("stop_completion") + @service_api_ns.doc(description="Stop a running completion task") + @service_api_ns.doc(params={"task_id": "The ID of the task to stop"}) + @service_api_ns.doc( + responses={ + 200: "Task stopped successfully", + 401: "Unauthorized - invalid API token", + 404: "Task not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) - def post(self, app_model: App, end_user: EndUser, task_id): + def post(self, app_model: App, end_user: EndUser, task_id: str): + """Stop a running completion task.""" if app_model.mode != "completion": raise AppUnavailableError() @@ -99,23 +158,33 @@ class CompletionStopApi(Resource): return {"result": "success"}, 200 +@service_api_ns.route("/chat-messages") class ChatApi(Resource): + @service_api_ns.expect(chat_parser) + @service_api_ns.doc("create_chat_message") + @service_api_ns.doc(description="Send a message in a chat conversation") + @service_api_ns.doc( + responses={ + 200: "Message sent successfully", + 400: "Bad request - invalid parameters or workflow issues", + 401: "Unauthorized - invalid API token", + 404: "Conversation or workflow not found", + 429: "Rate limit exceeded", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): + """Send a message in a chat conversation. + + This endpoint handles chat messages for chat, agent chat, and advanced chat applications. + Supports conversation management and both blocking and streaming response modes. + """ app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, required=True, location="json") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - parser.add_argument("conversation_id", type=uuid_value, location="json") - parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") - parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json") - parser.add_argument("workflow_id", type=str, required=False, location="json") - args = parser.parse_args() + args = chat_parser.parse_args() external_trace_id = get_external_trace_id(request) if external_trace_id: @@ -159,9 +228,21 @@ class ChatApi(Resource): raise InternalServerError() +@service_api_ns.route("/chat-messages//stop") class ChatStopApi(Resource): + @service_api_ns.doc("stop_chat_message") + @service_api_ns.doc(description="Stop a running chat message generation") + @service_api_ns.doc(params={"task_id": "The ID of the task to stop"}) + @service_api_ns.doc( + responses={ + 200: "Task stopped successfully", + 401: "Unauthorized - invalid API token", + 404: "Task not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) - def post(self, app_model: App, end_user: EndUser, task_id): + def post(self, app_model: App, end_user: EndUser, task_id: str): + """Stop a running chat message generation.""" app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -169,9 +250,3 @@ class ChatStopApi(Resource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) return {"result": "success"}, 200 - - -api.add_resource(CompletionApi, "/completion-messages") -api.add_resource(CompletionStopApi, "/completion-messages//stop") -api.add_resource(ChatApi, "/chat-messages") -api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index c2cccb6ef..4860bf3a7 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,48 +1,97 @@ -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, reqparse from flask_restx.inputs import int_range from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound import services -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import ( - conversation_delete_fields, - conversation_infinite_scroll_pagination_fields, - simple_conversation_fields, + build_conversation_delete_model, + build_conversation_infinite_scroll_pagination_model, + build_simple_conversation_model, ) from fields.conversation_variable_fields import ( - conversation_variable_fields, - conversation_variable_infinite_scroll_pagination_fields, + build_conversation_variable_infinite_scroll_pagination_model, + build_conversation_variable_model, ) from libs.helper import uuid_value from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService +# Define parsers for conversation APIs +conversation_list_parser = reqparse.RequestParser() +conversation_list_parser.add_argument( + "last_id", type=uuid_value, location="args", help="Last conversation ID for pagination" +) +conversation_list_parser.add_argument( + "limit", + type=int_range(1, 100), + required=False, + default=20, + location="args", + help="Number of conversations to return", +) +conversation_list_parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + help="Sort order for conversations", +) +conversation_rename_parser = reqparse.RequestParser() +conversation_rename_parser.add_argument("name", type=str, required=False, location="json", help="New conversation name") +conversation_rename_parser.add_argument( + "auto_generate", type=bool, required=False, default=False, location="json", help="Auto-generate conversation name" +) + +conversation_variables_parser = reqparse.RequestParser() +conversation_variables_parser.add_argument( + "last_id", type=uuid_value, location="args", help="Last variable ID for pagination" +) +conversation_variables_parser.add_argument( + "limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of variables to return" +) + +conversation_variable_update_parser = reqparse.RequestParser() +# using lambda is for passing the already-typed value without modification +# if no lambda, it will be converted to string +# the string cannot be converted using json.loads +conversation_variable_update_parser.add_argument( + "value", required=True, location="json", type=lambda x: x, help="New value for the conversation variable" +) + + +@service_api_ns.route("/conversations") class ConversationApi(Resource): + @service_api_ns.expect(conversation_list_parser) + @service_api_ns.doc("list_conversations") + @service_api_ns.doc(description="List all conversations for the current user") + @service_api_ns.doc( + responses={ + 200: "Conversations retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Last conversation not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - @marshal_with(conversation_infinite_scroll_pagination_fields) + @service_api_ns.marshal_with(build_conversation_infinite_scroll_pagination_model(service_api_ns)) def get(self, app_model: App, end_user: EndUser): + """List all conversations for the current user. + + Supports pagination using last_id and limit parameters. + """ app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("last_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - parser.add_argument( - "sort_by", - type=str, - choices=["created_at", "-created_at", "updated_at", "-updated_at"], - required=False, - default="-updated_at", - location="args", - ) - args = parser.parse_args() + args = conversation_list_parser.parse_args() try: with Session(db.engine) as session: @@ -59,10 +108,22 @@ class ConversationApi(Resource): raise NotFound("Last Conversation Not Exists.") +@service_api_ns.route("/conversations/") class ConversationDetailApi(Resource): + @service_api_ns.doc("delete_conversation") + @service_api_ns.doc(description="Delete a specific conversation") + @service_api_ns.doc(params={"c_id": "Conversation ID"}) + @service_api_ns.doc( + responses={ + 204: "Conversation deleted successfully", + 401: "Unauthorized - invalid API token", + 404: "Conversation not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @marshal_with(conversation_delete_fields) + @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=204) def delete(self, app_model: App, end_user: EndUser, c_id): + """Delete a specific conversation.""" app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -76,20 +137,30 @@ class ConversationDetailApi(Resource): return {"result": "success"}, 204 +@service_api_ns.route("/conversations//name") class ConversationRenameApi(Resource): + @service_api_ns.expect(conversation_rename_parser) + @service_api_ns.doc("rename_conversation") + @service_api_ns.doc(description="Rename a conversation or auto-generate a name") + @service_api_ns.doc(params={"c_id": "Conversation ID"}) + @service_api_ns.doc( + responses={ + 200: "Conversation renamed successfully", + 401: "Unauthorized - invalid API token", + 404: "Conversation not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @marshal_with(simple_conversation_fields) + @service_api_ns.marshal_with(build_simple_conversation_model(service_api_ns)) def post(self, app_model: App, end_user: EndUser, c_id): + """Rename a conversation or auto-generate a name.""" app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=False, location="json") - parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") - args = parser.parse_args() + args = conversation_rename_parser.parse_args() try: return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) @@ -97,10 +168,26 @@ class ConversationRenameApi(Resource): raise NotFound("Conversation Not Exists.") +@service_api_ns.route("/conversations//variables") class ConversationVariablesApi(Resource): + @service_api_ns.expect(conversation_variables_parser) + @service_api_ns.doc("list_conversation_variables") + @service_api_ns.doc(description="List all variables for a conversation") + @service_api_ns.doc(params={"c_id": "Conversation ID"}) + @service_api_ns.doc( + responses={ + 200: "Variables retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Conversation not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - @marshal_with(conversation_variable_infinite_scroll_pagination_fields) + @service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns)) def get(self, app_model: App, end_user: EndUser, c_id): + """List all variables for a conversation. + + Conversational variables are only available for chat applications. + """ # conversational variable only for chat app app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -108,10 +195,7 @@ class ConversationVariablesApi(Resource): conversation_id = str(c_id) - parser = reqparse.RequestParser() - parser.add_argument("last_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - args = parser.parse_args() + args = conversation_variables_parser.parse_args() try: return ConversationService.get_conversational_variable( @@ -121,11 +205,28 @@ class ConversationVariablesApi(Resource): raise NotFound("Conversation Not Exists.") +@service_api_ns.route("/conversations//variables/") class ConversationVariableDetailApi(Resource): + @service_api_ns.expect(conversation_variable_update_parser) + @service_api_ns.doc("update_conversation_variable") + @service_api_ns.doc(description="Update a conversation variable's value") + @service_api_ns.doc(params={"c_id": "Conversation ID", "variable_id": "Variable ID"}) + @service_api_ns.doc( + responses={ + 200: "Variable updated successfully", + 400: "Bad request - type mismatch", + 401: "Unauthorized - invalid API token", + 404: "Conversation or variable not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @marshal_with(conversation_variable_fields) + @service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns)) def put(self, app_model: App, end_user: EndUser, c_id, variable_id): - """Update a conversation variable's value""" + """Update a conversation variable's value. + + Allows updating the value of a specific conversation variable. + The value must match the variable's expected type. + """ app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -133,12 +234,7 @@ class ConversationVariableDetailApi(Resource): conversation_id = str(c_id) variable_id = str(variable_id) - parser = reqparse.RequestParser() - # using lambda is for passing the already-typed value without modification - # if no lambda, it will be converted to string - # the string cannot be converted using json.loads - parser.add_argument("value", required=True, location="json", type=lambda x: x) - args = parser.parse_args() + args = conversation_variable_update_parser.parse_args() try: return ConversationService.update_conversation_variable( @@ -150,15 +246,3 @@ class ConversationVariableDetailApi(Resource): raise NotFound("Conversation Variable Not Exists.") except services.errors.conversation.ConversationVariableTypeMismatchError as e: raise BadRequest(str(e)) - - -api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="conversation_name") -api.add_resource(ConversationApi, "/conversations") -api.add_resource(ConversationDetailApi, "/conversations/", endpoint="conversation_detail") -api.add_resource(ConversationVariablesApi, "/conversations//variables", endpoint="conversation_variables") -api.add_resource( - ConversationVariableDetailApi, - "/conversations//variables/", - endpoint="conversation_variable_detail", - methods=["PUT"], -) diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index 2efe51359..05f27545b 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -1,5 +1,6 @@ from flask import request -from flask_restx import Resource, marshal_with +from flask_restx import Resource +from flask_restx.api import HTTPStatus import services from controllers.common.errors import ( @@ -9,17 +10,33 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token -from fields.file_fields import file_fields +from fields.file_fields import build_file_model from models.model import App, EndUser from services.file_service import FileService +@service_api_ns.route("/files/upload") class FileApi(Resource): + @service_api_ns.doc("upload_file") + @service_api_ns.doc(description="Upload a file for use in conversations") + @service_api_ns.doc( + responses={ + 201: "File uploaded successfully", + 400: "Bad request - no file or invalid file", + 401: "Unauthorized - invalid API token", + 413: "File too large", + 415: "Unsupported file type", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) - @marshal_with(file_fields) + @service_api_ns.marshal_with(build_file_model(service_api_ns), code=HTTPStatus.CREATED) def post(self, app_model: App, end_user: EndUser): + """Upload a file for use in conversations. + + Accepts a single file upload via multipart/form-data. + """ # check file if "file" not in request.files: raise NoFileUploadedError() @@ -47,6 +64,3 @@ class FileApi(Resource): raise UnsupportedFileTypeError() return upload_file, 201 - - -api.add_resource(FileApi, "/files/upload") diff --git a/api/controllers/service_api/app/file_preview.py b/api/controllers/service_api/app/file_preview.py index d51862b35..84d80ea10 100644 --- a/api/controllers/service_api/app/file_preview.py +++ b/api/controllers/service_api/app/file_preview.py @@ -4,7 +4,7 @@ from urllib.parse import quote from flask import Response from flask_restx import Resource, reqparse -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( FileAccessDeniedError, FileNotFoundError, @@ -17,6 +17,14 @@ from models.model import App, EndUser, Message, MessageFile, UploadFile logger = logging.getLogger(__name__) +# Define parser for file preview API +file_preview_parser = reqparse.RequestParser() +file_preview_parser.add_argument( + "as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment" +) + + +@service_api_ns.route("/files//preview") class FilePreviewApi(Resource): """ Service API File Preview endpoint @@ -25,33 +33,30 @@ class FilePreviewApi(Resource): Files can only be accessed if they belong to messages within the requesting app's context. """ + @service_api_ns.expect(file_preview_parser) + @service_api_ns.doc("preview_file") + @service_api_ns.doc(description="Preview or download a file uploaded via Service API") + @service_api_ns.doc(params={"file_id": "UUID of the file to preview"}) + @service_api_ns.doc( + responses={ + 200: "File retrieved successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - file access denied", + 404: "File not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) def get(self, app_model: App, end_user: EndUser, file_id: str): """ - Preview/Download a file that was uploaded via Service API + Preview/Download a file that was uploaded via Service API. - Args: - app_model: The authenticated app model - end_user: The authenticated end user (optional) - file_id: UUID of the file to preview - - Query Parameters: - user: Optional user identifier - as_attachment: Boolean, whether to download as attachment (default: false) - - Returns: - Stream response with file content - - Raises: - FileNotFoundError: File does not exist - FileAccessDeniedError: File access denied (not owned by app) + Provides secure file preview/download functionality. + Files can only be accessed if they belong to messages within the requesting app's context. """ file_id = str(file_id) # Parse query parameters - parser = reqparse.RequestParser() - parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args") - args = parser.parse_args() + args = file_preview_parser.parse_args() # Validate file ownership and get file objects message_file, upload_file = self._validate_file_ownership(file_id, app_model.id) @@ -180,7 +185,3 @@ class FilePreviewApi(Resource): response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour return response - - -# Register the API endpoint -api.add_resource(FilePreviewApi, "/files//preview") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 52efebe87..ad3fac700 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,17 +1,17 @@ import json import logging -from flask_restx import Resource, fields, marshal_with, reqparse +from flask_restx import Api, Namespace, Resource, fields, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom -from fields.conversation_fields import message_file_fields -from fields.message_fields import agent_thought_fields, feedback_fields +from fields.conversation_fields import build_message_file_model +from fields.message_fields import build_agent_thought_model, build_feedback_model from fields.raws import FilesContainedField from libs.helper import TimestampField, uuid_value from models.model import App, AppMode, EndUser @@ -22,8 +22,37 @@ from services.errors.message import ( ) from services.message_service import MessageService +# Define parsers for message APIs +message_list_parser = reqparse.RequestParser() +message_list_parser.add_argument( + "conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID" +) +message_list_parser.add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination") +message_list_parser.add_argument( + "limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of messages to return" +) -class MessageListApi(Resource): +message_feedback_parser = reqparse.RequestParser() +message_feedback_parser.add_argument( + "rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating" +) +message_feedback_parser.add_argument("content", type=str, location="json", help="Feedback content") + +feedback_list_parser = reqparse.RequestParser() +feedback_list_parser.add_argument("page", type=int, default=1, location="args", help="Page number") +feedback_list_parser.add_argument( + "limit", type=int_range(1, 101), required=False, default=20, location="args", help="Number of feedbacks per page" +) + + +def build_message_model(api_or_ns: Api | Namespace): + """Build the message model for the API or Namespace.""" + # First build the nested models + feedback_model = build_feedback_model(api_or_ns) + agent_thought_model = build_agent_thought_model(api_or_ns) + message_file_model = build_message_file_model(api_or_ns) + + # Then build the message fields with nested models message_fields = { "id": fields.String, "conversation_id": fields.String, @@ -31,37 +60,58 @@ class MessageListApi(Resource): "inputs": FilesContainedField, "query": fields.String, "answer": fields.String(attribute="re_sign_file_url_answer"), - "message_files": fields.List(fields.Nested(message_file_fields)), - "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "message_files": fields.List(fields.Nested(message_file_model)), + "feedback": fields.Nested(feedback_model, attribute="user_feedback", allow_null=True), "retriever_resources": fields.Raw( attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", []) if obj.message_metadata else [] ), "created_at": TimestampField, - "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "agent_thoughts": fields.List(fields.Nested(agent_thought_model)), "status": fields.String, "error": fields.String, } + return api_or_ns.model("Message", message_fields) + + +def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): + """Build the message infinite scroll pagination model for the API or Namespace.""" + # Build the nested message model first + message_model = build_message_model(api_or_ns) message_infinite_scroll_pagination_fields = { "limit": fields.Integer, "has_more": fields.Boolean, - "data": fields.List(fields.Nested(message_fields)), + "data": fields.List(fields.Nested(message_model)), } + return api_or_ns.model("MessageInfiniteScrollPagination", message_infinite_scroll_pagination_fields) + +@service_api_ns.route("/messages") +class MessageListApi(Resource): + @service_api_ns.expect(message_list_parser) + @service_api_ns.doc("list_messages") + @service_api_ns.doc(description="List messages in a conversation") + @service_api_ns.doc( + responses={ + 200: "Messages retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Conversation or first message not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - @marshal_with(message_infinite_scroll_pagination_fields) + @service_api_ns.marshal_with(build_message_infinite_scroll_pagination_model(service_api_ns)) def get(self, app_model: App, end_user: EndUser): + """List messages in a conversation. + + Retrieves messages with pagination support using first_id. + """ app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") - parser.add_argument("first_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - args = parser.parse_args() + args = message_list_parser.parse_args() try: return MessageService.pagination_by_first_id( @@ -73,15 +123,28 @@ class MessageListApi(Resource): raise NotFound("First Message Not Exists.") +@service_api_ns.route("/messages//feedbacks") class MessageFeedbackApi(Resource): + @service_api_ns.expect(message_feedback_parser) + @service_api_ns.doc("create_message_feedback") + @service_api_ns.doc(description="Submit feedback for a message") + @service_api_ns.doc(params={"message_id": "Message ID"}) + @service_api_ns.doc( + responses={ + 200: "Feedback submitted successfully", + 401: "Unauthorized - invalid API token", + 404: "Message not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, message_id): + """Submit feedback for a message. + + Allows users to rate messages as like/dislike and provide optional feedback content. + """ message_id = str(message_id) - parser = reqparse.RequestParser() - parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") - parser.add_argument("content", type=str, location="json") - args = parser.parse_args() + args = message_feedback_parser.parse_args() try: MessageService.create_feedback( @@ -97,21 +160,48 @@ class MessageFeedbackApi(Resource): return {"result": "success"} +@service_api_ns.route("/app/feedbacks") class AppGetFeedbacksApi(Resource): + @service_api_ns.expect(feedback_list_parser) + @service_api_ns.doc("get_app_feedbacks") + @service_api_ns.doc(description="Get all feedbacks for the application") + @service_api_ns.doc( + responses={ + 200: "Feedbacks retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_app_token def get(self, app_model: App): - """Get All Feedbacks of an app""" - parser = reqparse.RequestParser() - parser.add_argument("page", type=int, default=1, location="args") - parser.add_argument("limit", type=int_range(1, 101), required=False, default=20, location="args") - args = parser.parse_args() + """Get all feedbacks for the application. + + Returns paginated list of all feedback submitted for messages in this app. + """ + args = feedback_list_parser.parse_args() feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"]) return {"data": feedbacks} +@service_api_ns.route("/messages//suggested") class MessageSuggestedApi(Resource): + @service_api_ns.doc("get_suggested_questions") + @service_api_ns.doc(description="Get suggested follow-up questions for a message") + @service_api_ns.doc(params={"message_id": "Message ID"}) + @service_api_ns.doc( + responses={ + 200: "Suggested questions retrieved successfully", + 400: "Suggested questions feature is disabled", + 401: "Unauthorized - invalid API token", + 404: "Message not found", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True)) def get(self, app_model: App, end_user: EndUser, message_id): + """Get suggested follow-up questions for a message. + + Returns AI-generated follow-up questions based on the message content. + """ message_id = str(message_id) app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -130,9 +220,3 @@ class MessageSuggestedApi(Resource): raise InternalServerError() return {"result": "success", "data": questions} - - -api.add_resource(MessageListApi, "/messages") -api.add_resource(MessageFeedbackApi, "/messages//feedbacks") -api.add_resource(MessageSuggestedApi, "/messages//suggested") -api.add_resource(AppGetFeedbacksApi, "/app/feedbacks") diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py index a57ab9179..9f8324a84 100644 --- a/api/controllers/service_api/app/site.py +++ b/api/controllers/service_api/app/site.py @@ -1,30 +1,41 @@ -from flask_restx import Resource, marshal_with +from flask_restx import Resource from werkzeug.exceptions import Forbidden -from controllers.common import fields -from controllers.service_api import api +from controllers.common.fields import build_site_model +from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token from extensions.ext_database import db from models.account import TenantStatus from models.model import App, Site +@service_api_ns.route("/site") class AppSiteApi(Resource): """Resource for app sites.""" + @service_api_ns.doc("get_app_site") + @service_api_ns.doc(description="Get application site configuration") + @service_api_ns.doc( + responses={ + 200: "Site configuration retrieved successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - site not found or tenant archived", + } + ) @validate_app_token - @marshal_with(fields.site_fields) + @service_api_ns.marshal_with(build_site_model(service_api_ns)) def get(self, app_model: App): - """Retrieve app site info.""" + """Retrieve app site info. + + Returns the site configuration for the application including theme, icons, and text. + """ site = db.session.query(Site).where(Site.app_id == app_model.id).first() if not site: raise Forbidden() + assert app_model.tenant if app_model.tenant.status == TenantStatus.ARCHIVE: raise Forbidden() return site - - -api.add_resource(AppSiteApi, "/site") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 24659bdb3..19e2e67d7 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -2,12 +2,12 @@ import logging from dateutil.parser import isoparse from flask import request -from flask_restx import Resource, fields, marshal_with, reqparse +from flask_restx import Api, Namespace, Resource, fields, reqparse from flask_restx.inputs import int_range from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( CompletionRequestError, NotWorkflowAppError, @@ -28,7 +28,7 @@ from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from core.workflow.entities.workflow_execution import WorkflowExecutionStatus from extensions.ext_database import db -from fields.workflow_app_log_fields import workflow_app_log_pagination_fields +from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model from libs import helper from libs.helper import TimestampField from models.model import App, AppMode, EndUser @@ -40,6 +40,34 @@ from services.workflow_app_service import WorkflowAppService logger = logging.getLogger(__name__) +# Define parsers for workflow APIs +workflow_run_parser = reqparse.RequestParser() +workflow_run_parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") +workflow_run_parser.add_argument("files", type=list, required=False, location="json") +workflow_run_parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + +workflow_log_parser = reqparse.RequestParser() +workflow_log_parser.add_argument("keyword", type=str, location="args") +workflow_log_parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") +workflow_log_parser.add_argument("created_at__before", type=str, location="args") +workflow_log_parser.add_argument("created_at__after", type=str, location="args") +workflow_log_parser.add_argument( + "created_by_end_user_session_id", + type=str, + location="args", + required=False, + default=None, +) +workflow_log_parser.add_argument( + "created_by_account", + type=str, + location="args", + required=False, + default=None, +) +workflow_log_parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") +workflow_log_parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") + workflow_run_fields = { "id": fields.String, "workflow_id": fields.String, @@ -55,12 +83,29 @@ workflow_run_fields = { } +def build_workflow_run_model(api_or_ns: Api | Namespace): + """Build the workflow run model for the API or Namespace.""" + return api_or_ns.model("WorkflowRun", workflow_run_fields) + + +@service_api_ns.route("/workflows/run/") class WorkflowRunDetailApi(Resource): + @service_api_ns.doc("get_workflow_run_detail") + @service_api_ns.doc(description="Get workflow run details") + @service_api_ns.doc(params={"workflow_run_id": "Workflow run ID"}) + @service_api_ns.doc( + responses={ + 200: "Workflow run details retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Workflow run not found", + } + ) @validate_app_token - @marshal_with(workflow_run_fields) + @service_api_ns.marshal_with(build_workflow_run_model(service_api_ns)) def get(self, app_model: App, workflow_run_id: str): - """ - Get a workflow task running detail + """Get a workflow task running detail. + + Returns detailed information about a specific workflow run. """ app_mode = AppMode.value_of(app_model.mode) if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]: @@ -78,21 +123,33 @@ class WorkflowRunDetailApi(Resource): return workflow_run +@service_api_ns.route("/workflows/run") class WorkflowRunApi(Resource): + @service_api_ns.expect(workflow_run_parser) + @service_api_ns.doc("run_workflow") + @service_api_ns.doc(description="Execute a workflow") + @service_api_ns.doc( + responses={ + 200: "Workflow executed successfully", + 400: "Bad request - invalid parameters or workflow issues", + 401: "Unauthorized - invalid API token", + 404: "Workflow not found", + 429: "Rate limit exceeded", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): - """ - Run workflow + """Execute a workflow. + + Runs a workflow with the provided inputs and returns the results. + Supports both blocking and streaming response modes. """ app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - args = parser.parse_args() + args = workflow_run_parser.parse_args() external_trace_id = get_external_trace_id(request) if external_trace_id: args["external_trace_id"] = external_trace_id @@ -121,21 +178,33 @@ class WorkflowRunApi(Resource): raise InternalServerError() +@service_api_ns.route("/workflows//run") class WorkflowRunByIdApi(Resource): + @service_api_ns.expect(workflow_run_parser) + @service_api_ns.doc("run_workflow_by_id") + @service_api_ns.doc(description="Execute a specific workflow by ID") + @service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"}) + @service_api_ns.doc( + responses={ + 200: "Workflow executed successfully", + 400: "Bad request - invalid parameters or workflow issues", + 401: "Unauthorized - invalid API token", + 404: "Workflow not found", + 429: "Rate limit exceeded", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, workflow_id: str): - """ - Run specific workflow by ID + """Run specific workflow by ID. + + Executes a specific workflow version identified by its ID. """ app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - args = parser.parse_args() + args = workflow_run_parser.parse_args() # Add workflow_id to args for AppGenerateService args["workflow_id"] = workflow_id @@ -174,12 +243,21 @@ class WorkflowRunByIdApi(Resource): raise InternalServerError() +@service_api_ns.route("/workflows/tasks//stop") class WorkflowTaskStopApi(Resource): + @service_api_ns.doc("stop_workflow_task") + @service_api_ns.doc(description="Stop a running workflow task") + @service_api_ns.doc(params={"task_id": "Task ID to stop"}) + @service_api_ns.doc( + responses={ + 200: "Task stopped successfully", + 401: "Unauthorized - invalid API token", + 404: "Task not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, task_id: str): - """ - Stop workflow task - """ + """Stop a running workflow task.""" app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() @@ -189,35 +267,25 @@ class WorkflowTaskStopApi(Resource): return {"result": "success"} +@service_api_ns.route("/workflows/logs") class WorkflowAppLogApi(Resource): + @service_api_ns.expect(workflow_log_parser) + @service_api_ns.doc("get_workflow_logs") + @service_api_ns.doc(description="Get workflow execution logs") + @service_api_ns.doc( + responses={ + 200: "Logs retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_app_token - @marshal_with(workflow_app_log_pagination_fields) + @service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns)) def get(self, app_model: App): + """Get workflow app logs. + + Returns paginated workflow execution logs with filtering options. """ - Get workflow app logs - """ - parser = reqparse.RequestParser() - parser.add_argument("keyword", type=str, location="args") - parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") - parser.add_argument("created_at__before", type=str, location="args") - parser.add_argument("created_at__after", type=str, location="args") - parser.add_argument( - "created_by_end_user_session_id", - type=str, - location="args", - required=False, - default=None, - ) - parser.add_argument( - "created_by_account", - type=str, - location="args", - required=False, - default=None, - ) - parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") - parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") - args = parser.parse_args() + args = workflow_log_parser.parse_args() args.status = WorkflowExecutionStatus(args.status) if args.status else None if args.created_at__before: @@ -243,10 +311,3 @@ class WorkflowAppLogApi(Resource): ) return workflow_app_log_pagination - - -api.add_resource(WorkflowRunApi, "/workflows/run") -api.add_resource(WorkflowRunDetailApi, "/workflows/run/") -api.add_resource(WorkflowRunByIdApi, "/workflows//run") -api.add_resource(WorkflowTaskStopApi, "/workflows/tasks//stop") -api.add_resource(WorkflowAppLogApi, "/workflows/logs") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 2f3a2d6ea..c486b0480 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,11 +1,11 @@ from typing import Literal from flask import request -from flask_restx import marshal, marshal_with, reqparse +from flask_restx import marshal, reqparse from werkzeug.exceptions import Forbidden, NotFound import services.dataset_service -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from controllers.service_api.wraps import ( DatasetApiResource, @@ -16,7 +16,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.plugin.entities.plugin import ModelProviderID from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields -from fields.tag_fields import tag_fields +from fields.tag_fields import build_dataset_tag_fields from libs.login import current_user from models.dataset import Dataset, DatasetPermissionEnum from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService @@ -36,12 +36,171 @@ def _validate_description_length(description): return description +# Define parsers for dataset operations +dataset_create_parser = reqparse.RequestParser() +dataset_create_parser.add_argument( + "name", + nullable=False, + required=True, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, +) +dataset_create_parser.add_argument( + "description", + type=_validate_description_length, + nullable=True, + required=False, + default="", +) +dataset_create_parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + help="Invalid indexing technique.", +) +dataset_create_parser.add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", + required=False, + nullable=False, +) +dataset_create_parser.add_argument( + "external_knowledge_api_id", + type=str, + nullable=True, + required=False, + default="_validate_name", +) +dataset_create_parser.add_argument( + "provider", + type=str, + nullable=True, + required=False, + default="vendor", +) +dataset_create_parser.add_argument( + "external_knowledge_id", + type=str, + nullable=True, + required=False, +) +dataset_create_parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") +dataset_create_parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") +dataset_create_parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") + +dataset_update_parser = reqparse.RequestParser() +dataset_update_parser.add_argument( + "name", + nullable=False, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, +) +dataset_update_parser.add_argument( + "description", location="json", store_missing=False, type=_validate_description_length +) +dataset_update_parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + help="Invalid indexing technique.", +) +dataset_update_parser.add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", +) +dataset_update_parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") +dataset_update_parser.add_argument( + "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." +) +dataset_update_parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") +dataset_update_parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") +dataset_update_parser.add_argument( + "external_retrieval_model", + type=dict, + required=False, + nullable=True, + location="json", + help="Invalid external retrieval model.", +) +dataset_update_parser.add_argument( + "external_knowledge_id", + type=str, + required=False, + nullable=True, + location="json", + help="Invalid external knowledge id.", +) +dataset_update_parser.add_argument( + "external_knowledge_api_id", + type=str, + required=False, + nullable=True, + location="json", + help="Invalid external knowledge api id.", +) + +tag_create_parser = reqparse.RequestParser() +tag_create_parser.add_argument( + "name", + nullable=False, + required=True, + help="Name must be between 1 to 50 characters.", + type=lambda x: x + if x and 1 <= len(x) <= 50 + else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")), +) + +tag_update_parser = reqparse.RequestParser() +tag_update_parser.add_argument( + "name", + nullable=False, + required=True, + help="Name must be between 1 to 50 characters.", + type=lambda x: x + if x and 1 <= len(x) <= 50 + else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")), +) +tag_update_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) + +tag_delete_parser = reqparse.RequestParser() +tag_delete_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) + +tag_binding_parser = reqparse.RequestParser() +tag_binding_parser.add_argument( + "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." +) +tag_binding_parser.add_argument( + "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required." +) + +tag_unbinding_parser = reqparse.RequestParser() +tag_unbinding_parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") +tag_unbinding_parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") + + +@service_api_ns.route("/datasets") class DatasetListApi(DatasetApiResource): """Resource for datasets.""" + @service_api_ns.doc("list_datasets") + @service_api_ns.doc(description="List all datasets") + @service_api_ns.doc( + responses={ + 200: "Datasets retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) def get(self, tenant_id): """Resource for getting datasets.""" - page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) # provider = request.args.get("provider", default="vendor") @@ -76,65 +235,20 @@ class DatasetListApi(DatasetApiResource): response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 + @service_api_ns.expect(dataset_create_parser) + @service_api_ns.doc("create_dataset") + @service_api_ns.doc(description="Create a new dataset") + @service_api_ns.doc( + responses={ + 200: "Dataset created successfully", + 401: "Unauthorized - invalid API token", + 400: "Bad request - invalid parameters", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id): """Resource for creating datasets.""" - parser = reqparse.RequestParser() - parser.add_argument( - "name", - nullable=False, - required=True, - help="type is required. Name must be between 1 to 40 characters.", - type=_validate_name, - ) - parser.add_argument( - "description", - type=_validate_description_length, - nullable=True, - required=False, - default="", - ) - parser.add_argument( - "indexing_technique", - type=str, - location="json", - choices=Dataset.INDEXING_TECHNIQUE_LIST, - help="Invalid indexing technique.", - ) - parser.add_argument( - "permission", - type=str, - location="json", - choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), - help="Invalid permission.", - required=False, - nullable=False, - ) - parser.add_argument( - "external_knowledge_api_id", - type=str, - nullable=True, - required=False, - default="_validate_name", - ) - parser.add_argument( - "provider", - type=str, - nullable=True, - required=False, - default="vendor", - ) - parser.add_argument( - "external_knowledge_id", - type=str, - nullable=True, - required=False, - ) - parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") - parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") - parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") - - args = parser.parse_args() + args = dataset_create_parser.parse_args() if args.get("embedding_model_provider"): DatasetService.check_embedding_model_setting( @@ -174,9 +288,21 @@ class DatasetListApi(DatasetApiResource): return marshal(dataset, dataset_detail_fields), 200 +@service_api_ns.route("/datasets/") class DatasetApi(DatasetApiResource): """Resource for dataset.""" + @service_api_ns.doc("get_dataset") + @service_api_ns.doc(description="Get a specific dataset by ID") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Dataset retrieved successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Dataset not found", + } + ) def get(self, _, dataset_id): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -216,6 +342,18 @@ class DatasetApi(DatasetApiResource): return data, 200 + @service_api_ns.expect(dataset_update_parser) + @service_api_ns.doc("update_dataset") + @service_api_ns.doc(description="Update an existing dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Dataset updated successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Dataset not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, _, dataset_id): dataset_id_str = str(dataset_id) @@ -223,63 +361,7 @@ class DatasetApi(DatasetApiResource): if dataset is None: raise NotFound("Dataset not found.") - parser = reqparse.RequestParser() - parser.add_argument( - "name", - nullable=False, - help="type is required. Name must be between 1 to 40 characters.", - type=_validate_name, - ) - parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length) - parser.add_argument( - "indexing_technique", - type=str, - location="json", - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help="Invalid indexing technique.", - ) - parser.add_argument( - "permission", - type=str, - location="json", - choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), - help="Invalid permission.", - ) - parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") - parser.add_argument( - "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." - ) - parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") - parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") - - parser.add_argument( - "external_retrieval_model", - type=dict, - required=False, - nullable=True, - location="json", - help="Invalid external retrieval model.", - ) - - parser.add_argument( - "external_knowledge_id", - type=str, - required=False, - nullable=True, - location="json", - help="Invalid external knowledge id.", - ) - - parser.add_argument( - "external_knowledge_api_id", - type=str, - required=False, - nullable=True, - location="json", - help="Invalid external knowledge api id.", - ) - args = parser.parse_args() + args = dataset_update_parser.parse_args() data = request.get_json() # check embedding model setting @@ -327,6 +409,17 @@ class DatasetApi(DatasetApiResource): return result_data, 200 + @service_api_ns.doc("delete_dataset") + @service_api_ns.doc(description="Delete a dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 204: "Dataset deleted successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + 409: "Conflict - dataset is in use", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, _, dataset_id): """ @@ -357,9 +450,27 @@ class DatasetApi(DatasetApiResource): raise DatasetInUseError() +@service_api_ns.route("/datasets//documents/status/") class DocumentStatusApi(DatasetApiResource): """Resource for batch document status operations.""" + @service_api_ns.doc("update_document_status") + @service_api_ns.doc(description="Batch update document status") + @service_api_ns.doc( + params={ + "dataset_id": "Dataset ID", + "action": "Action to perform: 'enable', 'disable', 'archive', or 'un_archive'", + } + ) + @service_api_ns.doc( + responses={ + 200: "Document status updated successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Dataset not found", + 400: "Bad request - invalid action", + } + ) def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]): """ Batch update document status. @@ -407,53 +518,65 @@ class DocumentStatusApi(DatasetApiResource): return {"result": "success"}, 200 +@service_api_ns.route("/datasets/tags") class DatasetTagsApi(DatasetApiResource): + @service_api_ns.doc("list_dataset_tags") + @service_api_ns.doc(description="Get all knowledge type tags") + @service_api_ns.doc( + responses={ + 200: "Tags retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_dataset_token - @marshal_with(tag_fields) + @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def get(self, _, dataset_id): """Get all knowledge type tags.""" tags = TagService.get_tags("knowledge", current_user.current_tenant_id) return tags, 200 + @service_api_ns.expect(tag_create_parser) + @service_api_ns.doc("create_dataset_tag") + @service_api_ns.doc(description="Add a knowledge type tag") + @service_api_ns.doc( + responses={ + 200: "Tag created successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + } + ) + @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) @validate_dataset_token def post(self, _, dataset_id): """Add a knowledge type tag.""" if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 50 characters.", - type=DatasetTagsApi._validate_tag_name, - ) - - args = parser.parse_args() + args = tag_create_parser.parse_args() args["type"] = "knowledge" tag = TagService.save_tags(args) response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} - return response, 200 + @service_api_ns.expect(tag_update_parser) + @service_api_ns.doc("update_dataset_tag") + @service_api_ns.doc(description="Update a knowledge type tag") + @service_api_ns.doc( + responses={ + 200: "Tag updated successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + } + ) + @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) @validate_dataset_token def patch(self, _, dataset_id): if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 50 characters.", - type=DatasetTagsApi._validate_tag_name, - ) - parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) - args = parser.parse_args() + args = tag_update_parser.parse_args() args["type"] = "knowledge" tag = TagService.update_tags(args, args.get("tag_id")) @@ -463,66 +586,88 @@ class DatasetTagsApi(DatasetApiResource): return response, 200 + @service_api_ns.expect(tag_delete_parser) + @service_api_ns.doc("delete_dataset_tag") + @service_api_ns.doc(description="Delete a knowledge type tag") + @service_api_ns.doc( + responses={ + 204: "Tag deleted successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + } + ) @validate_dataset_token def delete(self, _, dataset_id): """Delete a knowledge type tag.""" if not current_user.is_editor: raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) - args = parser.parse_args() + args = tag_delete_parser.parse_args() TagService.delete_tag(args.get("tag_id")) return 204 - @staticmethod - def _validate_tag_name(name): - if not name or len(name) < 1 or len(name) > 50: - raise ValueError("Name must be between 1 to 50 characters.") - return name - +@service_api_ns.route("/datasets/tags/binding") class DatasetTagBindingApi(DatasetApiResource): + @service_api_ns.expect(tag_binding_parser) + @service_api_ns.doc("bind_dataset_tags") + @service_api_ns.doc(description="Bind tags to a dataset") + @service_api_ns.doc( + responses={ + 204: "Tags bound successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + } + ) @validate_dataset_token def post(self, _, dataset_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument( - "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." - ) - parser.add_argument( - "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required." - ) - - args = parser.parse_args() + args = tag_binding_parser.parse_args() args["type"] = "knowledge" TagService.save_tag_binding(args) return 204 +@service_api_ns.route("/datasets/tags/unbinding") class DatasetTagUnbindingApi(DatasetApiResource): + @service_api_ns.expect(tag_unbinding_parser) + @service_api_ns.doc("unbind_dataset_tag") + @service_api_ns.doc(description="Unbind a tag from a dataset") + @service_api_ns.doc( + responses={ + 204: "Tag unbound successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + } + ) @validate_dataset_token def post(self, _, dataset_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") - parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") - - args = parser.parse_args() + args = tag_unbinding_parser.parse_args() args["type"] = "knowledge" TagService.delete_tag_binding(args) return 204 +@service_api_ns.route("/datasets//tags") class DatasetTagsBindingStatusApi(DatasetApiResource): + @service_api_ns.doc("get_dataset_tags_binding_status") + @service_api_ns.doc(description="Get tags bound to a specific dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Tags retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_dataset_token def get(self, _, *args, **kwargs): """Get all knowledge type tags.""" @@ -531,12 +676,3 @@ class DatasetTagsBindingStatusApi(DatasetApiResource): tags_list = [{"id": tag.id, "name": tag.name} for tag in tags] response = {"data": tags_list, "total": len(tags)} return response, 200 - - -api.add_resource(DatasetListApi, "/datasets") -api.add_resource(DatasetApi, "/datasets/") -api.add_resource(DocumentStatusApi, "/datasets//documents/status/") -api.add_resource(DatasetTagsApi, "/datasets/tags") -api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding") -api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding") -api.add_resource(DatasetTagsBindingStatusApi, "/datasets//tags") diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index c2e52fe00..43232229c 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -13,7 +13,7 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.dataset.error import ( ArchivedDocumentImmutableError, @@ -34,32 +34,64 @@ from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.file_service import FileService +# Define parsers for document operations +document_text_create_parser = reqparse.RequestParser() +document_text_create_parser.add_argument("name", type=str, required=True, nullable=False, location="json") +document_text_create_parser.add_argument("text", type=str, required=True, nullable=False, location="json") +document_text_create_parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") +document_text_create_parser.add_argument("original_document_id", type=str, required=False, location="json") +document_text_create_parser.add_argument( + "doc_form", type=str, default="text_model", required=False, nullable=False, location="json" +) +document_text_create_parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" +) +document_text_create_parser.add_argument( + "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" +) +document_text_create_parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") +document_text_create_parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") +document_text_create_parser.add_argument( + "embedding_model_provider", type=str, required=False, nullable=True, location="json" +) +document_text_update_parser = reqparse.RequestParser() +document_text_update_parser.add_argument("name", type=str, required=False, nullable=True, location="json") +document_text_update_parser.add_argument("text", type=str, required=False, nullable=True, location="json") +document_text_update_parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") +document_text_update_parser.add_argument( + "doc_form", type=str, default="text_model", required=False, nullable=False, location="json" +) +document_text_update_parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" +) +document_text_update_parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") + + +@service_api_ns.route( + "/datasets//document/create_by_text", + "/datasets//document/create-by-text", +) class DocumentAddByTextApi(DatasetApiResource): """Resource for documents.""" + @service_api_ns.expect(document_text_create_parser) + @service_api_ns.doc("create_document_by_text") + @service_api_ns.doc(description="Create a new document by providing text content") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Document created successfully", + 401: "Unauthorized - invalid API token", + 400: "Bad request - invalid parameters", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("documents", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create document by text.""" - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=True, nullable=False, location="json") - parser.add_argument("text", type=str, required=True, nullable=False, location="json") - parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") - parser.add_argument("original_document_id", type=str, required=False, location="json") - parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") - parser.add_argument( - "doc_language", type=str, default="English", required=False, nullable=False, location="json" - ) - parser.add_argument( - "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" - ) - parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") - parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") - parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") - - args = parser.parse_args() + args = document_text_create_parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) @@ -117,23 +149,29 @@ class DocumentAddByTextApi(DatasetApiResource): return documents_and_batch_fields, 200 +@service_api_ns.route( + "/datasets//documents//update_by_text", + "/datasets//documents//update-by-text", +) class DocumentUpdateByTextApi(DatasetApiResource): """Resource for update documents.""" + @service_api_ns.expect(document_text_update_parser) + @service_api_ns.doc("update_document_by_text") + @service_api_ns.doc(description="Update an existing document by providing text content") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Document updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Document not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by text.""" - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=False, nullable=True, location="json") - parser.add_argument("text", type=str, required=False, nullable=True, location="json") - parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") - parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") - parser.add_argument( - "doc_language", type=str, default="English", required=False, nullable=False, location="json" - ) - parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") - args = parser.parse_args() + args = document_text_update_parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() @@ -187,9 +225,23 @@ class DocumentUpdateByTextApi(DatasetApiResource): return documents_and_batch_fields, 200 +@service_api_ns.route( + "/datasets//document/create_by_file", + "/datasets//document/create-by-file", +) class DocumentAddByFileApi(DatasetApiResource): """Resource for documents.""" + @service_api_ns.doc("create_document_by_file") + @service_api_ns.doc(description="Create a new document by uploading a file") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Document created successfully", + 401: "Unauthorized - invalid API token", + 400: "Bad request - invalid file or parameters", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("documents", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") @@ -281,9 +333,23 @@ class DocumentAddByFileApi(DatasetApiResource): return documents_and_batch_fields, 200 +@service_api_ns.route( + "/datasets//documents//update_by_file", + "/datasets//documents//update-by-file", +) class DocumentUpdateByFileApi(DatasetApiResource): """Resource for update documents.""" + @service_api_ns.doc("update_document_by_file") + @service_api_ns.doc(description="Update an existing document by uploading a file") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Document updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Document not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): @@ -358,7 +424,18 @@ class DocumentUpdateByFileApi(DatasetApiResource): return documents_and_batch_fields, 200 +@service_api_ns.route("/datasets//documents") class DocumentListApi(DatasetApiResource): + @service_api_ns.doc("list_documents") + @service_api_ns.doc(description="List all documents in a dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Documents retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) def get(self, tenant_id, dataset_id): dataset_id = str(dataset_id) tenant_id = str(tenant_id) @@ -391,7 +468,18 @@ class DocumentListApi(DatasetApiResource): return response +@service_api_ns.route("/datasets//documents//indexing-status") class DocumentIndexingStatusApi(DatasetApiResource): + @service_api_ns.doc("get_document_indexing_status") + @service_api_ns.doc(description="Get indexing status for documents in a batch") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "batch": "Batch ID"}) + @service_api_ns.doc( + responses={ + 200: "Indexing status retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset or documents not found", + } + ) def get(self, tenant_id, dataset_id, batch): dataset_id = str(dataset_id) batch = str(batch) @@ -440,9 +528,21 @@ class DocumentIndexingStatusApi(DatasetApiResource): return data +@service_api_ns.route("/datasets//documents/") class DocumentApi(DatasetApiResource): METADATA_CHOICES = {"all", "only", "without"} + @service_api_ns.doc("get_document") + @service_api_ns.doc(description="Get a specific document by ID") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Document retrieved successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Document not found", + } + ) def get(self, tenant_id, dataset_id, document_id): dataset_id = str(dataset_id) document_id = str(document_id) @@ -534,6 +634,17 @@ class DocumentApi(DatasetApiResource): return response + @service_api_ns.doc("delete_document") + @service_api_ns.doc(description="Delete a document") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 204: "Document deleted successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - document is archived", + 404: "Document not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id, dataset_id, document_id): """Delete document.""" @@ -564,28 +675,3 @@ class DocumentApi(DatasetApiResource): raise DocumentIndexingError("Cannot delete document during indexing.") return 204 - - -api.add_resource( - DocumentAddByTextApi, - "/datasets//document/create_by_text", - "/datasets//document/create-by-text", -) -api.add_resource( - DocumentAddByFileApi, - "/datasets//document/create_by_file", - "/datasets//document/create-by-file", -) -api.add_resource( - DocumentUpdateByTextApi, - "/datasets//documents//update_by_text", - "/datasets//documents//update-by-text", -) -api.add_resource( - DocumentUpdateByFileApi, - "/datasets//documents//update_by_file", - "/datasets//documents//update-by-file", -) -api.add_resource(DocumentApi, "/datasets//documents/") -api.add_resource(DocumentListApi, "/datasets//documents") -api.add_resource(DocumentIndexingStatusApi, "/datasets//documents//indexing-status") diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index 52e9bca5d..d81287d56 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -1,11 +1,26 @@ from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check +@service_api_ns.route("/datasets//hit-testing", "/datasets//retrieve") class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): + @service_api_ns.doc("dataset_hit_testing") + @service_api_ns.doc(description="Perform hit testing on a dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Hit testing results", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): + """Perform hit testing on a dataset. + + Tests retrieval performance for the specified dataset. + """ dataset_id_str = str(dataset_id) dataset = self.get_and_validate_dataset(dataset_id_str) @@ -13,6 +28,3 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): self.hit_testing_args_check(args) return self.perform_hit_testing(dataset, args) - - -api.add_resource(HitTestingApi, "/datasets//hit-testing", "/datasets//retrieve") diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index c1542c031..9defe6af0 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -4,7 +4,7 @@ from flask_login import current_user # type: ignore from flask_restx import marshal, reqparse from werkzeug.exceptions import NotFound -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check from fields.dataset_fields import dataset_metadata_fields from services.dataset_service import DatasetService @@ -14,14 +14,43 @@ from services.entities.knowledge_entities.knowledge_entities import ( ) from services.metadata_service import MetadataService +# Define parsers for metadata APIs +metadata_create_parser = reqparse.RequestParser() +metadata_create_parser.add_argument( + "type", type=str, required=True, nullable=False, location="json", help="Metadata type" +) +metadata_create_parser.add_argument( + "name", type=str, required=True, nullable=False, location="json", help="Metadata name" +) +metadata_update_parser = reqparse.RequestParser() +metadata_update_parser.add_argument( + "name", type=str, required=True, nullable=False, location="json", help="New metadata name" +) + +document_metadata_parser = reqparse.RequestParser() +document_metadata_parser.add_argument( + "operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data" +) + + +@service_api_ns.route("/datasets//metadata") class DatasetMetadataCreateServiceApi(DatasetApiResource): + @service_api_ns.expect(metadata_create_parser) + @service_api_ns.doc("create_dataset_metadata") + @service_api_ns.doc(description="Create metadata for a dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 201: "Metadata created successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): - parser = reqparse.RequestParser() - parser.add_argument("type", type=str, required=True, nullable=False, location="json") - parser.add_argument("name", type=str, required=True, nullable=False, location="json") - args = parser.parse_args() + """Create metadata for a dataset.""" + args = metadata_create_parser.parse_args() metadata_args = MetadataArgs(**args) dataset_id_str = str(dataset_id) @@ -33,7 +62,18 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): metadata = MetadataService.create_metadata(dataset_id_str, metadata_args) return marshal(metadata, dataset_metadata_fields), 201 + @service_api_ns.doc("get_dataset_metadata") + @service_api_ns.doc(description="Get all metadata for a dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Metadata retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) def get(self, tenant_id, dataset_id): + """Get all metadata for a dataset.""" dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -41,12 +81,23 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): return MetadataService.get_dataset_metadatas(dataset), 200 +@service_api_ns.route("/datasets//metadata/") class DatasetMetadataServiceApi(DatasetApiResource): + @service_api_ns.expect(metadata_update_parser) + @service_api_ns.doc("update_dataset_metadata") + @service_api_ns.doc(description="Update metadata name") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"}) + @service_api_ns.doc( + responses={ + 200: "Metadata updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset or metadata not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, tenant_id, dataset_id, metadata_id): - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=True, nullable=False, location="json") - args = parser.parse_args() + """Update metadata name.""" + args = metadata_update_parser.parse_args() dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) @@ -58,8 +109,19 @@ class DatasetMetadataServiceApi(DatasetApiResource): metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) return marshal(metadata, dataset_metadata_fields), 200 + @service_api_ns.doc("delete_dataset_metadata") + @service_api_ns.doc(description="Delete metadata") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"}) + @service_api_ns.doc( + responses={ + 204: "Metadata deleted successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset or metadata not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id, dataset_id, metadata_id): + """Delete metadata.""" dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -71,15 +133,37 @@ class DatasetMetadataServiceApi(DatasetApiResource): return 204 +@service_api_ns.route("/datasets/metadata/built-in") class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource): + @service_api_ns.doc("get_built_in_fields") + @service_api_ns.doc(description="Get all built-in metadata fields") + @service_api_ns.doc( + responses={ + 200: "Built-in fields retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) def get(self, tenant_id): + """Get all built-in metadata fields.""" built_in_fields = MetadataService.get_built_in_fields() return {"fields": built_in_fields}, 200 +@service_api_ns.route("/datasets//metadata/built-in/") class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): + @service_api_ns.doc("toggle_built_in_field") + @service_api_ns.doc(description="Enable or disable built-in metadata field") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "action": "Action to perform: 'enable' or 'disable'"}) + @service_api_ns.doc( + responses={ + 200: "Action completed successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]): + """Enable or disable built-in metadata field.""" dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -93,29 +177,31 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): return 200 +@service_api_ns.route("/datasets//documents/metadata") class DocumentMetadataEditServiceApi(DatasetApiResource): + @service_api_ns.expect(document_metadata_parser) + @service_api_ns.doc("update_documents_metadata") + @service_api_ns.doc(description="Update metadata for multiple documents") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Documents metadata updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): + """Update metadata for multiple documents.""" dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - parser = reqparse.RequestParser() - parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json") - args = parser.parse_args() + args = document_metadata_parser.parse_args() metadata_args = MetadataOperationData(**args) MetadataService.update_documents_metadata(dataset, metadata_args) return 200 - - -api.add_resource(DatasetMetadataCreateServiceApi, "/datasets//metadata") -api.add_resource(DatasetMetadataServiceApi, "/datasets//metadata/") -api.add_resource(DatasetMetadataBuiltInFieldServiceApi, "/datasets/metadata/built-in") -api.add_resource( - DatasetMetadataBuiltInFieldActionServiceApi, "/datasets//metadata/built-in/" -) -api.add_resource(DocumentMetadataEditServiceApi, "/datasets//documents/metadata") diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index ba8461c7c..f5e2010ca 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -3,7 +3,7 @@ from flask_login import current_user from flask_restx import marshal, reqparse from werkzeug.exceptions import NotFound -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.wraps import ( DatasetApiResource, @@ -19,34 +19,59 @@ from fields.segment_fields import child_chunk_fields, segment_fields from models.dataset import Dataset from services.dataset_service import DatasetService, DocumentService, SegmentService from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs -from services.errors.chunk import ( - ChildChunkDeleteIndexError, - ChildChunkIndexingError, -) -from services.errors.chunk import ( - ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError, -) -from services.errors.chunk import ( - ChildChunkIndexingError as ChildChunkIndexingServiceError, -) +from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError +from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError +from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError + +# Define parsers for segment operations +segment_create_parser = reqparse.RequestParser() +segment_create_parser.add_argument("segments", type=list, required=False, nullable=True, location="json") + +segment_list_parser = reqparse.RequestParser() +segment_list_parser.add_argument("status", type=str, action="append", default=[], location="args") +segment_list_parser.add_argument("keyword", type=str, default=None, location="args") + +segment_update_parser = reqparse.RequestParser() +segment_update_parser.add_argument("segment", type=dict, required=False, nullable=True, location="json") + +child_chunk_create_parser = reqparse.RequestParser() +child_chunk_create_parser.add_argument("content", type=str, required=True, nullable=False, location="json") + +child_chunk_list_parser = reqparse.RequestParser() +child_chunk_list_parser.add_argument("limit", type=int, default=20, location="args") +child_chunk_list_parser.add_argument("keyword", type=str, default=None, location="args") +child_chunk_list_parser.add_argument("page", type=int, default=1, location="args") + +child_chunk_update_parser = reqparse.RequestParser() +child_chunk_update_parser.add_argument("content", type=str, required=True, nullable=False, location="json") +@service_api_ns.route("/datasets//documents//segments") class SegmentApi(DatasetApiResource): """Resource for segments.""" + @service_api_ns.expect(segment_create_parser) + @service_api_ns.doc("create_segments") + @service_api_ns.doc(description="Create segments in a document") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Segments created successfully", + 400: "Bad request - segments data is missing", + 401: "Unauthorized - invalid API token", + 404: "Dataset or document not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id, document_id): + def post(self, tenant_id: str, dataset_id: str, document_id: str): """Create single segment.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") @@ -71,9 +96,7 @@ class SegmentApi(DatasetApiResource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # validate args - parser = reqparse.RequestParser() - parser.add_argument("segments", type=list, required=False, nullable=True, location="json") - args = parser.parse_args() + args = segment_create_parser.parse_args() if args["segments"] is not None: for args_item in args["segments"]: SegmentService.segment_create_args_validate(args_item, document) @@ -82,18 +105,26 @@ class SegmentApi(DatasetApiResource): else: return {"error": "Segments is required"}, 400 - def get(self, tenant_id, dataset_id, document_id): + @service_api_ns.expect(segment_list_parser) + @service_api_ns.doc("list_segments") + @service_api_ns.doc(description="List segments in a document") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Segments retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset or document not found", + } + ) + def get(self, tenant_id: str, dataset_id: str, document_id: str): """Get segments.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") @@ -114,10 +145,7 @@ class SegmentApi(DatasetApiResource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) - parser = reqparse.RequestParser() - parser.add_argument("status", type=str, action="append", default=[], location="args") - parser.add_argument("keyword", type=str, default=None, location="args") - args = parser.parse_args() + args = segment_list_parser.parse_args() segments, total = SegmentService.get_segments( document_id=document_id, @@ -140,43 +168,62 @@ class SegmentApi(DatasetApiResource): return response, 200 +@service_api_ns.route("/datasets//documents//segments/") class DatasetSegmentApi(DatasetApiResource): + @service_api_ns.doc("delete_segment") + @service_api_ns.doc(description="Delete a specific segment") + @service_api_ns.doc( + params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to delete"} + ) + @service_api_ns.doc( + responses={ + 204: "Segment deleted successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or segment not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def delete(self, tenant_id, dataset_id, document_id, segment_id): + def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document - document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") SegmentService.delete_segment(segment, document, dataset) return 204 + @service_api_ns.expect(segment_update_parser) + @service_api_ns.doc("update_segment") + @service_api_ns.doc(description="Update a specific segment") + @service_api_ns.doc( + params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to update"} + ) + @service_api_ns.doc( + responses={ + 200: "Segment updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or segment not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id, document_id, segment_id): + def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document - document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") @@ -197,37 +244,39 @@ class DatasetSegmentApi(DatasetApiResource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") # validate args - parser = reqparse.RequestParser() - parser.add_argument("segment", type=dict, required=False, nullable=True, location="json") - args = parser.parse_args() + args = segment_update_parser.parse_args() updated_segment = SegmentService.update_segment( SegmentUpdateArgs(**args["segment"]), segment, document, dataset ) return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 - def get(self, tenant_id, dataset_id, document_id, segment_id): + @service_api_ns.doc("get_segment") + @service_api_ns.doc(description="Get a specific segment by ID") + @service_api_ns.doc( + responses={ + 200: "Segment retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or segment not found", + } + ) + def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document - document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -235,29 +284,41 @@ class DatasetSegmentApi(DatasetApiResource): return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 +@service_api_ns.route( + "/datasets//documents//segments//child_chunks" +) class ChildChunkApi(DatasetApiResource): """Resource for child chunks.""" + @service_api_ns.expect(child_chunk_create_parser) + @service_api_ns.doc("create_child_chunk") + @service_api_ns.doc(description="Create a new child chunk for a segment") + @service_api_ns.doc( + params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"} + ) + @service_api_ns.doc( + responses={ + 200: "Child chunk created successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or segment not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id, document_id, segment_id): + def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): """Create child chunk.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -280,43 +341,46 @@ class ChildChunkApi(DatasetApiResource): raise ProviderNotInitializeError(ex.description) # validate args - parser = reqparse.RequestParser() - parser.add_argument("content", type=str, required=True, nullable=False, location="json") - args = parser.parse_args() + args = child_chunk_create_parser.parse_args() try: - child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset) + child_chunk = SegmentService.create_child_chunk(args["content"], segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunk, child_chunk_fields)}, 200 - def get(self, tenant_id, dataset_id, document_id, segment_id): + @service_api_ns.expect(child_chunk_list_parser) + @service_api_ns.doc("list_child_chunks") + @service_api_ns.doc(description="List child chunks for a segment") + @service_api_ns.doc( + params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"} + ) + @service_api_ns.doc( + responses={ + 200: "Child chunks retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or segment not found", + } + ) + def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): """Get child chunks.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") - parser = reqparse.RequestParser() - parser.add_argument("limit", type=int, default=20, location="args") - parser.add_argument("keyword", type=str, default=None, location="args") - parser.add_argument("page", type=int, default=1, location="args") - args = parser.parse_args() + args = child_chunk_list_parser.parse_args() page = args["page"] limit = min(args["limit"], 100) @@ -333,28 +397,44 @@ class ChildChunkApi(DatasetApiResource): }, 200 +@service_api_ns.route( + "/datasets//documents//segments//child_chunks/" +) class DatasetChildChunkApi(DatasetApiResource): """Resource for updating child chunks.""" + @service_api_ns.doc("delete_child_chunk") + @service_api_ns.doc(description="Delete a specific child chunk") + @service_api_ns.doc( + params={ + "dataset_id": "Dataset ID", + "document_id": "Document ID", + "segment_id": "Parent segment ID", + "child_chunk_id": "Child chunk ID to delete", + } + ) + @service_api_ns.doc( + responses={ + 204: "Child chunk deleted successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, segment, or child chunk not found", + } + ) @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def delete(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id): + def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str): """Delete child chunk.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -364,7 +444,6 @@ class DatasetChildChunkApi(DatasetApiResource): raise NotFound("Document not found.") # check child chunk - child_chunk_id = str(child_chunk_id) child_chunk = SegmentService.get_child_chunk_by_id( child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id ) @@ -382,14 +461,30 @@ class DatasetChildChunkApi(DatasetApiResource): return 204 + @service_api_ns.expect(child_chunk_update_parser) + @service_api_ns.doc("update_child_chunk") + @service_api_ns.doc(description="Update a specific child chunk") + @service_api_ns.doc( + params={ + "dataset_id": "Dataset ID", + "document_id": "Document ID", + "segment_id": "Parent segment ID", + "child_chunk_id": "Child chunk ID to update", + } + ) + @service_api_ns.doc( + responses={ + 200: "Child chunk updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, segment, or child chunk not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def patch(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id): + def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str): """Update child chunk.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") @@ -420,28 +515,11 @@ class DatasetChildChunkApi(DatasetApiResource): raise NotFound("Child chunk not found.") # validate args - parser = reqparse.RequestParser() - parser.add_argument("content", type=str, required=True, nullable=False, location="json") - args = parser.parse_args() + args = child_chunk_update_parser.parse_args() try: - child_chunk = SegmentService.update_child_chunk( - args.get("content"), child_chunk, segment, document, dataset - ) + child_chunk = SegmentService.update_child_chunk(args["content"], child_chunk, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunk, child_chunk_fields)}, 200 - - -api.add_resource(SegmentApi, "/datasets//documents//segments") -api.add_resource( - DatasetSegmentApi, "/datasets//documents//segments/" -) -api.add_resource( - ChildChunkApi, "/datasets//documents//segments//child_chunks" -) -api.add_resource( - DatasetChildChunkApi, - "/datasets//documents//segments//child_chunks/", -) diff --git a/api/controllers/service_api/dataset/upload_file.py b/api/controllers/service_api/dataset/upload_file.py index 3b4721b5b..27b36a640 100644 --- a/api/controllers/service_api/dataset/upload_file.py +++ b/api/controllers/service_api/dataset/upload_file.py @@ -1,6 +1,6 @@ from werkzeug.exceptions import NotFound -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.wraps import ( DatasetApiResource, ) @@ -11,9 +11,23 @@ from models.model import UploadFile from services.dataset_service import DocumentService +@service_api_ns.route("/datasets//documents//upload-file") class UploadFileApi(DatasetApiResource): + @service_api_ns.doc("get_upload_file") + @service_api_ns.doc(description="Get upload file information and download URL") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Upload file information retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or upload file not found", + } + ) def get(self, tenant_id, dataset_id, document_id): - """Get upload file.""" + """Get upload file information and download URL. + + Returns information about an uploaded file including its download URL. + """ # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) @@ -49,6 +63,3 @@ class UploadFileApi(DatasetApiResource): "created_by": upload_file.created_by, "created_at": upload_file.created_at.timestamp(), }, 200 - - -api.add_resource(UploadFileApi, "/datasets//documents//upload-file") diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py index 343b946c5..a9d2d6fad 100644 --- a/api/controllers/service_api/index.py +++ b/api/controllers/service_api/index.py @@ -1,9 +1,10 @@ from flask_restx import Resource from configs import dify_config -from controllers.service_api import api +from controllers.service_api import service_api_ns +@service_api_ns.route("/") class IndexApi(Resource): def get(self): return { @@ -11,6 +12,3 @@ class IndexApi(Resource): "api_version": "v1", "server_version": dify_config.project.version, } - - -api.add_resource(IndexApi, "/") diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index 59ba2c7df..536cf81a2 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -1,21 +1,32 @@ from flask_login import current_user from flask_restx import Resource -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_dataset_token from core.model_runtime.utils.encoders import jsonable_encoder from services.model_provider_service import ModelProviderService +@service_api_ns.route("/workspaces/current/models/model-types/") class ModelProviderAvailableModelApi(Resource): + @service_api_ns.doc("get_available_models") + @service_api_ns.doc(description="Get available models by model type") + @service_api_ns.doc(params={"model_type": "Type of model to retrieve"}) + @service_api_ns.doc( + responses={ + 200: "Models retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_dataset_token def get(self, _, model_type): + """Get available models by model type. + + Returns a list of available models for the specified model type. + """ tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) return jsonable_encoder({"data": models}) - - -api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/") diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index a09c69dcb..38835d5ac 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -1,4 +1,4 @@ -from flask_restx import fields +from flask_restx import Api, Namespace, fields from libs.helper import TimestampField @@ -11,6 +11,12 @@ annotation_fields = { # 'account': fields.Nested(simple_account_fields, allow_null=True) } + +def build_annotation_model(api_or_ns: Api | Namespace): + """Build the annotation model for the API or Namespace.""" + return api_or_ns.model("Annotation", annotation_fields) + + annotation_list_fields = { "data": fields.List(fields.Nested(annotation_fields)), } diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 8d8a7f604..ecc267cf3 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -1,4 +1,4 @@ -from flask_restx import fields +from flask_restx import Api, Namespace, fields from fields.member_fields import simple_account_fields from libs.helper import TimestampField @@ -45,6 +45,12 @@ message_file_fields = { "upload_file_id": fields.String(default=None), } + +def build_message_file_model(api_or_ns: Api | Namespace): + """Build the message file fields for the API or Namespace.""" + return api_or_ns.model("MessageFile", message_file_fields) + + agent_thought_fields = { "id": fields.String, "chain_id": fields.String, @@ -209,3 +215,22 @@ conversation_infinite_scroll_pagination_fields = { "has_more": fields.Boolean, "data": fields.List(fields.Nested(simple_conversation_fields)), } + + +def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): + """Build the conversation infinite scroll pagination model for the API or Namespace.""" + simple_conversation_model = build_simple_conversation_model(api_or_ns) + + copied_fields = conversation_infinite_scroll_pagination_fields.copy() + copied_fields["data"] = fields.List(fields.Nested(simple_conversation_model)) + return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields) + + +def build_conversation_delete_model(api_or_ns: Api | Namespace): + """Build the conversation delete model for the API or Namespace.""" + return api_or_ns.model("ConversationDelete", conversation_delete_fields) + + +def build_simple_conversation_model(api_or_ns: Api | Namespace): + """Build the simple conversation model for the API or Namespace.""" + return api_or_ns.model("SimpleConversation", simple_conversation_fields) diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index 0b0f1f372..7d5e31159 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -1,4 +1,4 @@ -from flask_restx import fields +from flask_restx import Api, Namespace, fields from libs.helper import TimestampField @@ -27,3 +27,19 @@ conversation_variable_infinite_scroll_pagination_fields = { "has_more": fields.Boolean, "data": fields.List(fields.Nested(conversation_variable_fields)), } + + +def build_conversation_variable_model(api_or_ns: Api | Namespace): + """Build the conversation variable model for the API or Namespace.""" + return api_or_ns.model("ConversationVariable", conversation_variable_fields) + + +def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): + """Build the conversation variable infinite scroll pagination model for the API or Namespace.""" + # Build the nested variable model first + conversation_variable_model = build_conversation_variable_model(api_or_ns) + + copied_fields = conversation_variable_infinite_scroll_pagination_fields.copy() + copied_fields["data"] = fields.List(fields.Nested(conversation_variable_model)) + + return api_or_ns.model("ConversationVariableInfiniteScrollPagination", copied_fields) diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index 63b77688a..ea43e3b5f 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,4 +1,4 @@ -from flask_restx import fields +from flask_restx import Api, Namespace, fields simple_end_user_fields = { "id": fields.String, @@ -6,3 +6,7 @@ simple_end_user_fields = { "is_anonymous": fields.Boolean, "session_id": fields.String, } + + +def build_simple_end_user_model(api_or_ns: Api | Namespace): + return api_or_ns.model("SimpleEndUser", simple_end_user_fields) diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 3b453293c..dd359e2f5 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -1,4 +1,4 @@ -from flask_restx import fields +from flask_restx import Api, Namespace, fields from libs.helper import TimestampField @@ -11,6 +11,19 @@ upload_config_fields = { "workflow_file_upload_limit": fields.Integer, } + +def build_upload_config_model(api_or_ns: Api | Namespace): + """Build the upload config model for the API or Namespace. + + Args: + api_or_ns: Flask-RestX Api or Namespace instance + + Returns: + The registered model + """ + return api_or_ns.model("UploadConfig", upload_config_fields) + + file_fields = { "id": fields.String, "name": fields.String, @@ -22,12 +35,37 @@ file_fields = { "preview_url": fields.String, } + +def build_file_model(api_or_ns: Api | Namespace): + """Build the file model for the API or Namespace. + + Args: + api_or_ns: Flask-RestX Api or Namespace instance + + Returns: + The registered model + """ + return api_or_ns.model("File", file_fields) + + remote_file_info_fields = { "file_type": fields.String(attribute="file_type"), "file_length": fields.Integer(attribute="file_length"), } +def build_remote_file_info_model(api_or_ns: Api | Namespace): + """Build the remote file info model for the API or Namespace. + + Args: + api_or_ns: Flask-RestX Api or Namespace instance + + Returns: + The registered model + """ + return api_or_ns.model("RemoteFileInfo", remote_file_info_fields) + + file_fields_with_signed_url = { "id": fields.String, "name": fields.String, @@ -38,3 +76,15 @@ file_fields_with_signed_url = { "created_by": fields.String, "created_at": TimestampField, } + + +def build_file_with_signed_url_model(api_or_ns: Api | Namespace): + """Build the file with signed URL model for the API or Namespace. + + Args: + api_or_ns: Flask-RestX Api or Namespace instance + + Returns: + The registered model + """ + return api_or_ns.model("FileWithSignedUrl", file_fields_with_signed_url) diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 2cd95504f..08e38a693 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -1,8 +1,17 @@ -from flask_restx import fields +from flask_restx import Api, Namespace, fields from libs.helper import AvatarUrlField, TimestampField -simple_account_fields = {"id": fields.String, "name": fields.String, "email": fields.String} +simple_account_fields = { + "id": fields.String, + "name": fields.String, + "email": fields.String, +} + + +def build_simple_account_model(api_or_ns: Api | Namespace): + return api_or_ns.model("SimpleAccount", simple_account_fields) + account_fields = { "id": fields.String, diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index fe254a32e..a419da2e1 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -1,11 +1,19 @@ -from flask_restx import fields +from flask_restx import Api, Namespace, fields from fields.conversation_fields import message_file_fields from libs.helper import TimestampField from .raws import FilesContainedField -feedback_fields = {"rating": fields.String} +feedback_fields = { + "rating": fields.String, +} + + +def build_feedback_model(api_or_ns: Api | Namespace): + """Build the feedback model for the API or Namespace.""" + return api_or_ns.model("Feedback", feedback_fields) + agent_thought_fields = { "id": fields.String, @@ -21,6 +29,12 @@ agent_thought_fields = { "files": fields.List(fields.String), } + +def build_agent_thought_model(api_or_ns: Api | Namespace): + """Build the agent thought model for the API or Namespace.""" + return api_or_ns.model("AgentThought", agent_thought_fields) + + retriever_resource_fields = { "id": fields.String, "message_id": fields.String, diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index 0ba5bfbb7..d5b7c86a0 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,3 +1,12 @@ -from flask_restx import fields +from flask_restx import Api, Namespace, fields -tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String} +dataset_tag_fields = { + "id": fields.String, + "name": fields.String, + "type": fields.String, + "binding_count": fields.String, +} + + +def build_dataset_tag_fields(api_or_ns: Api | Namespace): + return api_or_ns.model("DataSetTag", dataset_tag_fields) diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index a60448375..243efd817 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,8 +1,8 @@ -from flask_restx import fields +from flask_restx import Api, Namespace, fields -from fields.end_user_fields import simple_end_user_fields -from fields.member_fields import simple_account_fields -from fields.workflow_run_fields import workflow_run_for_log_fields +from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields +from fields.member_fields import build_simple_account_model, simple_account_fields +from fields.workflow_run_fields import build_workflow_run_for_log_model, workflow_run_for_log_fields from libs.helper import TimestampField workflow_app_log_partial_fields = { @@ -15,6 +15,24 @@ workflow_app_log_partial_fields = { "created_at": TimestampField, } + +def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace): + """Build the workflow app log partial model for the API or Namespace.""" + workflow_run_model = build_workflow_run_for_log_model(api_or_ns) + simple_account_model = build_simple_account_model(api_or_ns) + simple_end_user_model = build_simple_end_user_model(api_or_ns) + + copied_fields = workflow_app_log_partial_fields.copy() + copied_fields["workflow_run"] = fields.Nested(workflow_run_model, attribute="workflow_run", allow_null=True) + copied_fields["created_by_account"] = fields.Nested( + simple_account_model, attribute="created_by_account", allow_null=True + ) + copied_fields["created_by_end_user"] = fields.Nested( + simple_end_user_model, attribute="created_by_end_user", allow_null=True + ) + return api_or_ns.model("WorkflowAppLogPartial", copied_fields) + + workflow_app_log_pagination_fields = { "page": fields.Integer, "limit": fields.Integer, @@ -22,3 +40,13 @@ workflow_app_log_pagination_fields = { "has_more": fields.Boolean, "data": fields.List(fields.Nested(workflow_app_log_partial_fields)), } + + +def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace): + """Build the workflow app log pagination model for the API or Namespace.""" + # Build the nested partial model first + workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns) + + copied_fields = workflow_app_log_pagination_fields.copy() + copied_fields["data"] = fields.List(fields.Nested(workflow_app_log_partial_model)) + return api_or_ns.model("WorkflowAppLogPagination", copied_fields) diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 5fb3a1102..6462d8ce5 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -1,4 +1,4 @@ -from flask_restx import fields +from flask_restx import Api, Namespace, fields from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields @@ -17,6 +17,11 @@ workflow_run_for_log_fields = { "exceptions_count": fields.Integer, } + +def build_workflow_run_for_log_model(api_or_ns: Api | Namespace): + return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields) + + workflow_run_for_list_fields = { "id": fields.String, "version": fields.String,