diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index fc1974901..eef9ddc76 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -1,18 +1,27 @@ from typing import Optional, Union +from flask import Response from flask_restx import Resource, reqparse from pydantic import ValidationError +from sqlalchemy.orm import Session from controllers.console.app.mcp_server import AppMCPServerStatus from controllers.mcp import mcp_ns from core.app.app_config.entities import VariableEntity -from core.mcp import types -from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler -from core.mcp.types import ClientNotification, ClientRequest -from core.mcp.utils import create_mcp_error_response +from core.mcp import types as mcp_types +from core.mcp.server.streamable_http import handle_mcp_request from extensions.ext_database import db from libs import helper -from models.model import App, AppMCPServer, AppMode +from models.model import App, AppMCPServer, AppMode, EndUser + + +class MCPRequestError(Exception): + """Custom exception for MCP request processing errors""" + + def __init__(self, error_code: int, message: str): + self.error_code = error_code + self.message = message + super().__init__(message) def int_or_str(value): @@ -63,77 +72,173 @@ class MCPAppApi(Resource): Raises: ValidationError: Invalid request format or parameters """ - # Parse and validate all arguments args = mcp_request_parser.parse_args() - request_id: Optional[Union[int, str]] = args.get("id") + mcp_request = self._parse_mcp_request(args) - server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first() - if not server: - return helper.compact_generate_response( - create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found") - ) + with Session(db.engine, expire_on_commit=False) as session: + # Get MCP server and app + mcp_server, app = self._get_mcp_server_and_app(server_code, session) + self._validate_server_status(mcp_server) - if server.status != AppMCPServerStatus.ACTIVE: - return helper.compact_generate_response( - create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active") - ) + # Get user input form + user_input_form = self._get_user_input_form(app) - app = db.session.query(App).where(App.id == server.app_id).first() + # Handle notification vs request differently + return self._process_mcp_message(mcp_request, request_id, app, mcp_server, user_input_form, session) + + def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]: + """Get and validate MCP server and app in one query session""" + mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first() + if not mcp_server: + raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found") + + app = session.query(App).where(App.id == mcp_server.app_id).first() if not app: - return helper.compact_generate_response( - create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found") - ) + raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found") - if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: - workflow = app.workflow - if workflow is None: - return helper.compact_generate_response( - create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable") - ) + return mcp_server, app - user_input_form = workflow.user_input_form(to_old_structure=True) + def _validate_server_status(self, mcp_server: AppMCPServer) -> None: + """Validate MCP server status""" + if mcp_server.status != AppMCPServerStatus.ACTIVE: + raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active") + + def _process_mcp_message( + self, + mcp_request: mcp_types.ClientRequest | mcp_types.ClientNotification, + request_id: Optional[Union[int, str]], + app: App, + mcp_server: AppMCPServer, + user_input_form: list[VariableEntity], + session: Session, + ) -> Response: + """Process MCP message (notification or request)""" + if isinstance(mcp_request, mcp_types.ClientNotification): + return self._handle_notification(mcp_request) else: - app_model_config = app.app_model_config - if app_model_config is None: - return helper.compact_generate_response( - create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable") - ) + return self._handle_request(mcp_request, request_id, app, mcp_server, user_input_form, session) - features_dict = app_model_config.to_dict() - user_input_form = features_dict.get("user_input_form", []) - converted_user_input_form: list[VariableEntity] = [] - try: - for item in user_input_form: - variable_type = item.get("type", "") or list(item.keys())[0] - variable = item[variable_type] - converted_user_input_form.append( - VariableEntity( - type=variable_type, - variable=variable.get("variable"), - description=variable.get("description") or "", - label=variable.get("label"), - required=variable.get("required", False), - max_length=variable.get("max_length"), - options=variable.get("options") or [], - ) - ) - except ValidationError as e: - return helper.compact_generate_response( - create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}") - ) + def _handle_notification(self, mcp_request: mcp_types.ClientNotification) -> Response: + """Handle MCP notification""" + # For notifications, only support init notification + if mcp_request.root.method != "notifications/initialized": + raise MCPRequestError(mcp_types.INVALID_REQUEST, "Invalid notification method") + # Return HTTP 202 Accepted for notifications (no response body) + return Response("", status=202, content_type="application/json") + def _handle_request( + self, + mcp_request: mcp_types.ClientRequest, + request_id: Optional[Union[int, str]], + app: App, + mcp_server: AppMCPServer, + user_input_form: list[VariableEntity], + session: Session, + ) -> Response: + """Handle MCP request""" + if request_id is None: + raise MCPRequestError(mcp_types.INVALID_REQUEST, "Request ID is required") + + result = self._handle_mcp_request(app, mcp_server, mcp_request, user_input_form, session, request_id) + if result is None: + # This shouldn't happen for requests, but handle gracefully + raise MCPRequestError(mcp_types.INTERNAL_ERROR, "No response generated for request") + + return helper.compact_generate_response(result.model_dump(by_alias=True, mode="json", exclude_none=True)) + + def _get_user_input_form(self, app: App) -> list[VariableEntity]: + """Get and convert user input form""" + # Get raw user input form based on app mode + if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if not app.workflow: + raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable") + raw_user_input_form = app.workflow.user_input_form(to_old_structure=True) + else: + if not app.app_model_config: + raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable") + features_dict = app.app_model_config.to_dict() + raw_user_input_form = features_dict.get("user_input_form", []) + + # Convert to VariableEntity objects try: - request: ClientRequest | ClientNotification = ClientRequest.model_validate(args) + return self._convert_user_input_form(raw_user_input_form) except ValidationError as e: + raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}") + + def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]: + """Convert raw user input form to VariableEntity objects""" + return [self._create_variable_entity(item) for item in raw_form] + + def _create_variable_entity(self, item: dict) -> VariableEntity: + """Create a single VariableEntity from raw form item""" + variable_type = item.get("type", "") or list(item.keys())[0] + variable = item[variable_type] + + return VariableEntity( + type=variable_type, + variable=variable.get("variable"), + description=variable.get("description") or "", + label=variable.get("label"), + required=variable.get("required", False), + max_length=variable.get("max_length"), + options=variable.get("options") or [], + ) + + def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification: + """Parse and validate MCP request""" + try: + return mcp_types.ClientRequest.model_validate(args) + except ValidationError: try: - notification = ClientNotification.model_validate(args) - request = notification + return mcp_types.ClientNotification.model_validate(args) except ValidationError as e: - return helper.compact_generate_response( - create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}") - ) + raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}") - mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form) - response = mcp_server_handler.handle() - return helper.compact_generate_response(response) + def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None: + """Get end user from existing session - optimized query""" + return ( + session.query(EndUser) + .where(EndUser.tenant_id == tenant_id) + .where(EndUser.session_id == mcp_server_id) + .where(EndUser.type == "mcp") + .first() + ) + + def _create_end_user( + self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session + ) -> EndUser: + """Create end user in existing session""" + end_user = EndUser( + tenant_id=tenant_id, + app_id=app_id, + type="mcp", + name=client_name, + session_id=mcp_server_id, + ) + session.add(end_user) + session.flush() # Use flush instead of commit to keep transaction open + session.refresh(end_user) + return end_user + + def _handle_mcp_request( + self, + app: App, + mcp_server: AppMCPServer, + mcp_request: mcp_types.ClientRequest, + user_input_form: list[VariableEntity], + session: Session, + request_id: Union[int, str], + ) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None: + """Handle MCP request and return response""" + end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session) + + if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest): + client_info = mcp_request.root.params.clientInfo + client_name = f"{client_info.name}@{client_info.version}" + # Commit the session before creating end user to avoid transaction conflicts + session.commit() + with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin(): + end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session) + + return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id) diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index efe91bbff..5851c6d40 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -4,224 +4,259 @@ from collections.abc import Mapping from typing import Any, cast from configs import dify_config -from controllers.web.passport import generate_session_id from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting.rate_limit import RateLimitGenerator -from core.mcp import types -from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND -from core.mcp.utils import create_mcp_error_response -from core.model_runtime.utils.encoders import jsonable_encoder -from extensions.ext_database import db +from core.mcp import types as mcp_types from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService logger = logging.getLogger(__name__) -class MCPServerStreamableHTTPRequestHandler: +def handle_mcp_request( + app: App, + request: mcp_types.ClientRequest, + user_input_form: list[VariableEntity], + mcp_server: AppMCPServer, + end_user: EndUser | None = None, + request_id: int | str = 1, +) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError: """ - Apply to MCP HTTP streamable server with stateless http + Handle MCP request and return JSON-RPC response + + Args: + app: The Dify app instance + request: The JSON-RPC request message + user_input_form: List of variable entities for the app + mcp_server: The MCP server configuration + end_user: Optional end user + request_id: The request ID + + Returns: + JSON-RPC response or error """ - def __init__( - self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity] - ): - self.app = app - self.request = request - mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first() - if not mcp_server: - raise ValueError("MCP server not found") - self.mcp_server: AppMCPServer = mcp_server - self.end_user = self.retrieve_end_user() - self.user_input_form = user_input_form + request_type = type(request.root) - @property - def request_type(self): - return type(self.request.root) + def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse: + """Create success response with business result data""" + return mcp_types.JSONRPCResponse( + jsonrpc="2.0", + id=request_id, + result=result_data.model_dump(by_alias=True, mode="json", exclude_none=True), + ) - @property - def parameter_schema(self): - parameters, required = self._convert_input_form_to_parameters(self.user_input_form) - if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}: - return { - "type": "object", - "properties": parameters, - "required": required, - } + def create_error_response(code: int, message: str) -> mcp_types.JSONRPCError: + """Create error response with error code and message""" + from core.mcp.types import ErrorData + + error_data = ErrorData(code=code, message=message) + return mcp_types.JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=error_data, + ) + + # Request handler mapping using functional approach + request_handlers = { + mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description), + mcp_types.ListToolsRequest: lambda: handle_list_tools( + app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict + ), + mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user), + mcp_types.PingRequest: lambda: handle_ping(), + } + + try: + # Dispatch request to appropriate handler + handler = request_handlers.get(request_type) + if handler: + return create_success_response(handler()) + else: + return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}") + + except ValueError as e: + logger.exception("Invalid params") + return create_error_response(mcp_types.INVALID_PARAMS, str(e)) + except Exception as e: + logger.exception("Internal server error") + return create_error_response(mcp_types.INTERNAL_ERROR, "Internal server error: " + str(e)) + + +def handle_ping() -> mcp_types.EmptyResult: + """Handle ping request""" + return mcp_types.EmptyResult() + + +def handle_initialize(description: str) -> mcp_types.InitializeResult: + """Handle initialize request""" + capabilities = mcp_types.ServerCapabilities( + tools=mcp_types.ToolsCapability(listChanged=False), + ) + + return mcp_types.InitializeResult( + protocolVersion=mcp_types.SERVER_LATEST_PROTOCOL_VERSION, + capabilities=capabilities, + serverInfo=mcp_types.Implementation(name="Dify", version=dify_config.project.version), + instructions=description, + ) + + +def handle_list_tools( + app_name: str, + app_mode: str, + user_input_form: list[VariableEntity], + description: str, + parameters_dict: dict[str, str], +) -> mcp_types.ListToolsResult: + """Handle list tools request""" + parameter_schema = build_parameter_schema(app_mode, user_input_form, parameters_dict) + + return mcp_types.ListToolsResult( + tools=[ + mcp_types.Tool( + name=app_name, + description=description, + inputSchema=parameter_schema, + ) + ], + ) + + +def handle_call_tool( + app: App, + request: mcp_types.ClientRequest, + user_input_form: list[VariableEntity], + end_user: EndUser | None, +) -> mcp_types.CallToolResult: + """Handle call tool request""" + request_obj = cast(mcp_types.CallToolRequest, request.root) + args = prepare_tool_arguments(app, request_obj.params.arguments or {}) + + if not end_user: + raise ValueError("End user not found") + + response = AppGenerateService.generate( + app, + end_user, + args, + InvokeFrom.SERVICE_API, + streaming=app.mode == AppMode.AGENT_CHAT.value, + ) + + answer = extract_answer_from_response(app, response) + return mcp_types.CallToolResult(content=[mcp_types.TextContent(text=answer, type="text")]) + + +def build_parameter_schema( + app_mode: str, + user_input_form: list[VariableEntity], + parameters_dict: dict[str, str], +) -> dict[str, Any]: + """Build parameter schema for the tool""" + parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict) + + if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}: return { "type": "object", - "properties": { - "query": {"type": "string", "description": "User Input/Question content"}, - **parameters, - }, - "required": ["query", *required], + "properties": parameters, + "required": required, } + return { + "type": "object", + "properties": { + "query": {"type": "string", "description": "User Input/Question content"}, + **parameters, + }, + "required": ["query", *required], + } - @property - def capabilities(self): - return types.ServerCapabilities( - tools=types.ToolsCapability(listChanged=False), - ) - def response(self, response: types.Result | str): - if isinstance(response, str): - sse_content = f"event: ping\ndata: {response}\n\n".encode() - yield sse_content - return - json_response = types.JSONRPCResponse( - jsonrpc="2.0", - id=(self.request.root.model_extra or {}).get("id", 1), - result=response.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - json_data = json.dumps(jsonable_encoder(json_response)) +def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]: + """Prepare arguments based on app mode""" + if app.mode == AppMode.WORKFLOW.value: + return {"inputs": arguments} + elif app.mode == AppMode.COMPLETION.value: + return {"query": "", "inputs": arguments} + else: + # Chat modes - create a copy to avoid modifying original dict + args_copy = arguments.copy() + query = args_copy.pop("query", "") + return {"query": query, "inputs": args_copy} - sse_content = f"event: message\ndata: {json_data}\n\n".encode() - yield sse_content +def extract_answer_from_response(app: App, response: Any) -> str: + """Extract answer from app generate response""" + answer = "" - def error_response(self, code: int, message: str, data=None): - request_id = (self.request.root.model_extra or {}).get("id", 1) or 1 - return create_mcp_error_response(request_id, code, message, data) + if isinstance(response, RateLimitGenerator): + answer = process_streaming_response(response) + elif isinstance(response, Mapping): + answer = process_mapping_response(app, response) + else: + logger.warning("Unexpected response type: %s", type(response)) - def handle(self): - handle_map = { - types.InitializeRequest: self.initialize, - types.ListToolsRequest: self.list_tools, - types.CallToolRequest: self.invoke_tool, - types.InitializedNotification: self.handle_notification, - types.PingRequest: self.handle_ping, - } - try: - if self.request_type in handle_map: - return self.response(handle_map[self.request_type]()) - else: - return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}") - except ValueError as e: - logger.exception("Invalid params") - return self.error_response(INVALID_PARAMS, str(e)) - except Exception as e: - logger.exception("Internal server error") - return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}") + return answer - def handle_notification(self): - return "ping" - def handle_ping(self): - return types.EmptyResult() - - def initialize(self): - request = cast(types.InitializeRequest, self.request.root) - client_info = request.params.clientInfo - client_name = f"{client_info.name}@{client_info.version}" - if not self.end_user: - end_user = EndUser( - tenant_id=self.app.tenant_id, - app_id=self.app.id, - type="mcp", - name=client_name, - session_id=generate_session_id(), - external_user_id=self.mcp_server.id, - ) - db.session.add(end_user) - db.session.commit() - return types.InitializeResult( - protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION, - capabilities=self.capabilities, - serverInfo=types.Implementation(name="Dify", version=dify_config.project.version), - instructions=self.mcp_server.description, - ) - - def list_tools(self): - if not self.end_user: - raise ValueError("User not found") - return types.ListToolsResult( - tools=[ - types.Tool( - name=self.app.name, - description=self.mcp_server.description, - inputSchema=self.parameter_schema, - ) - ], - ) - - def invoke_tool(self): - if not self.end_user: - raise ValueError("User not found") - request = cast(types.CallToolRequest, self.request.root) - args = request.params.arguments or {} - if self.app.mode in {AppMode.WORKFLOW.value}: - args = {"inputs": args} - elif self.app.mode in {AppMode.COMPLETION.value}: - args = {"query": "", "inputs": args} - else: - args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}} - response = AppGenerateService.generate( - self.app, - self.end_user, - args, - InvokeFrom.SERVICE_API, - streaming=self.app.mode == AppMode.AGENT_CHAT.value, - ) - answer = "" - if isinstance(response, RateLimitGenerator): - for item in response.generator: - data = item - if isinstance(data, str) and data.startswith("data: "): - try: - json_str = data[6:].strip() - parsed_data = json.loads(json_str) - if parsed_data.get("event") == "agent_thought": - answer += parsed_data.get("thought", "") - except json.JSONDecodeError: - continue - if isinstance(response, Mapping): - if self.app.mode in { - AppMode.ADVANCED_CHAT.value, - AppMode.COMPLETION.value, - AppMode.CHAT.value, - AppMode.AGENT_CHAT.value, - }: - answer = response["answer"] - elif self.app.mode in {AppMode.WORKFLOW.value}: - answer = json.dumps(response["data"]["outputs"], ensure_ascii=False) - else: - raise ValueError("Invalid app mode") - # Not support image yet - return types.CallToolResult(content=[types.TextContent(text=answer, type="text")]) - - def retrieve_end_user(self): - return ( - db.session.query(EndUser) - .where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") - .first() - ) - - def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]): - parameters: dict[str, dict[str, Any]] = {} - required = [] - for item in user_input_form: - parameters[item.variable] = {} - if item.type in ( - VariableEntityType.FILE, - VariableEntityType.FILE_LIST, - VariableEntityType.EXTERNAL_DATA_TOOL, - ): - continue - if item.required: - required.append(item.variable) - # if the workflow republished, the parameters not changed - # we should not raise error here +def process_streaming_response(response: RateLimitGenerator) -> str: + """Process streaming response for agent chat mode""" + answer = "" + for item in response.generator: + if isinstance(item, str) and item.startswith("data: "): try: - description = self.mcp_server.parameters_dict[item.variable] - except KeyError: - description = "" - parameters[item.variable]["description"] = description - if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): - parameters[item.variable]["type"] = "string" - elif item.type == VariableEntityType.SELECT: - parameters[item.variable]["type"] = "string" - parameters[item.variable]["enum"] = item.options - elif item.type == VariableEntityType.NUMBER: - parameters[item.variable]["type"] = "float" - return parameters, required + json_str = item[6:].strip() + parsed_data = json.loads(json_str) + if parsed_data.get("event") == "agent_thought": + answer += parsed_data.get("thought", "") + except json.JSONDecodeError: + continue + return answer + + +def process_mapping_response(app: App, response: Mapping) -> str: + """Process mapping response based on app mode""" + if app.mode in { + AppMode.ADVANCED_CHAT.value, + AppMode.COMPLETION.value, + AppMode.CHAT.value, + AppMode.AGENT_CHAT.value, + }: + return response.get("answer", "") + elif app.mode == AppMode.WORKFLOW.value: + return json.dumps(response["data"]["outputs"], ensure_ascii=False) + else: + raise ValueError("Invalid app mode: " + str(app.mode)) + + +def convert_input_form_to_parameters( + user_input_form: list[VariableEntity], + parameters_dict: dict[str, str], +) -> tuple[dict[str, dict[str, Any]], list[str]]: + """Convert user input form to parameter schema""" + parameters: dict[str, dict[str, Any]] = {} + required = [] + + for item in user_input_form: + if item.type in ( + VariableEntityType.FILE, + VariableEntityType.FILE_LIST, + VariableEntityType.EXTERNAL_DATA_TOOL, + ): + continue + parameters[item.variable] = {} + if item.required: + required.append(item.variable) + # if the workflow republished, the parameters not changed + # we should not raise error here + description = parameters_dict.get(item.variable, "") + parameters[item.variable]["description"] = description + if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): + parameters[item.variable]["type"] = "string" + elif item.type == VariableEntityType.SELECT: + parameters[item.variable]["type"] = "string" + parameters[item.variable]["enum"] = item.options + elif item.type == VariableEntityType.NUMBER: + parameters[item.variable]["type"] = "float" + return parameters, required diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index 80912bc4c..84bef7b93 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -138,5 +138,5 @@ def create_mcp_error_response( error=error_data, ) json_data = json.dumps(jsonable_encoder(json_response)) - sse_content = f"event: message\ndata: {json_data}\n\n".encode() + sse_content = json_data.encode() yield sse_content diff --git a/api/tests/unit_tests/core/mcp/server/__init__.py b/api/tests/unit_tests/core/mcp/server/__init__.py new file mode 100644 index 000000000..81af0ff1c --- /dev/null +++ b/api/tests/unit_tests/core/mcp/server/__init__.py @@ -0,0 +1 @@ +# MCP server tests diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py new file mode 100644 index 000000000..ccc5d42bc --- /dev/null +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -0,0 +1,449 @@ +import json +from unittest.mock import Mock, patch + +import pytest + +from core.app.app_config.entities import VariableEntity, VariableEntityType +from core.app.features.rate_limiting.rate_limit import RateLimitGenerator +from core.mcp import types +from core.mcp.server.streamable_http import ( + build_parameter_schema, + convert_input_form_to_parameters, + extract_answer_from_response, + handle_call_tool, + handle_initialize, + handle_list_tools, + handle_mcp_request, + handle_ping, + prepare_tool_arguments, + process_mapping_response, +) +from models.model import App, AppMCPServer, AppMode, EndUser + + +class TestHandleMCPRequest: + """Test handle_mcp_request function""" + + def setup_method(self): + """Setup test fixtures""" + self.app = Mock(spec=App) + self.app.name = "test_app" + self.app.mode = AppMode.CHAT.value + + self.mcp_server = Mock(spec=AppMCPServer) + self.mcp_server.description = "Test server" + self.mcp_server.parameters_dict = {} + + self.end_user = Mock(spec=EndUser) + self.user_input_form = [] + + # Create mock request + self.mock_request = Mock() + self.mock_request.root = Mock() + self.mock_request.root.id = 123 + + def test_handle_ping_request(self): + """Test handling ping request""" + # Setup ping request + self.mock_request.root = Mock(spec=types.PingRequest) + self.mock_request.root.id = 123 + request_type = Mock(return_value=types.PingRequest) + + with patch("core.mcp.server.streamable_http.type", request_type): + result = handle_mcp_request( + self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123 + ) + + assert isinstance(result, types.JSONRPCResponse) + assert result.jsonrpc == "2.0" + assert result.id == 123 + + def test_handle_initialize_request(self): + """Test handling initialize request""" + # Setup initialize request + self.mock_request.root = Mock(spec=types.InitializeRequest) + self.mock_request.root.id = 123 + request_type = Mock(return_value=types.InitializeRequest) + + with patch("core.mcp.server.streamable_http.type", request_type): + result = handle_mcp_request( + self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123 + ) + + assert isinstance(result, types.JSONRPCResponse) + assert result.jsonrpc == "2.0" + assert result.id == 123 + + def test_handle_list_tools_request(self): + """Test handling list tools request""" + # Setup list tools request + self.mock_request.root = Mock(spec=types.ListToolsRequest) + self.mock_request.root.id = 123 + request_type = Mock(return_value=types.ListToolsRequest) + + with patch("core.mcp.server.streamable_http.type", request_type): + result = handle_mcp_request( + self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123 + ) + + assert isinstance(result, types.JSONRPCResponse) + assert result.jsonrpc == "2.0" + assert result.id == 123 + + @patch("core.mcp.server.streamable_http.AppGenerateService") + def test_handle_call_tool_request(self, mock_app_generate): + """Test handling call tool request""" + # Setup call tool request + mock_call_request = Mock(spec=types.CallToolRequest) + mock_call_request.params = Mock() + mock_call_request.params.arguments = {"query": "test question"} + mock_call_request.id = 123 + + self.mock_request.root = mock_call_request + request_type = Mock(return_value=types.CallToolRequest) + + # Mock app generate service response + mock_response = {"answer": "test answer"} + mock_app_generate.generate.return_value = mock_response + + with patch("core.mcp.server.streamable_http.type", request_type): + result = handle_mcp_request( + self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123 + ) + + assert isinstance(result, types.JSONRPCResponse) + assert result.jsonrpc == "2.0" + assert result.id == 123 + + # Verify AppGenerateService was called + mock_app_generate.generate.assert_called_once() + + def test_handle_unknown_request_type(self): + """Test handling unknown request type""" + + # Setup unknown request + class UnknownRequest: + pass + + self.mock_request.root = Mock(spec=UnknownRequest) + self.mock_request.root.id = 123 + request_type = Mock(return_value=UnknownRequest) + + with patch("core.mcp.server.streamable_http.type", request_type): + result = handle_mcp_request( + self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123 + ) + + assert isinstance(result, types.JSONRPCError) + assert result.jsonrpc == "2.0" + assert result.id == 123 + assert result.error.code == types.METHOD_NOT_FOUND + + def test_handle_value_error(self): + """Test handling ValueError""" + # Setup request that will cause ValueError + self.mock_request.root = Mock(spec=types.CallToolRequest) + self.mock_request.root.params = Mock() + self.mock_request.root.params.arguments = {} + + request_type = Mock(return_value=types.CallToolRequest) + + # Don't provide end_user to cause ValueError + with patch("core.mcp.server.streamable_http.type", request_type): + result = handle_mcp_request(self.app, self.mock_request, self.user_input_form, self.mcp_server, None, 123) + + assert isinstance(result, types.JSONRPCError) + assert result.error.code == types.INVALID_PARAMS + + def test_handle_generic_exception(self): + """Test handling generic exception""" + # Setup request that will cause generic exception + self.mock_request.root = Mock(spec=types.PingRequest) + self.mock_request.root.id = 123 + + # Patch handle_ping to raise exception instead of type + with patch("core.mcp.server.streamable_http.handle_ping", side_effect=Exception("Test error")): + with patch("core.mcp.server.streamable_http.type", return_value=types.PingRequest): + result = handle_mcp_request( + self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123 + ) + + assert isinstance(result, types.JSONRPCError) + assert result.error.code == types.INTERNAL_ERROR + + +class TestIndividualHandlers: + """Test individual handler functions""" + + def test_handle_ping(self): + """Test ping handler""" + result = handle_ping() + assert isinstance(result, types.EmptyResult) + + def test_handle_initialize(self): + """Test initialize handler""" + description = "Test server" + + with patch("core.mcp.server.streamable_http.dify_config") as mock_config: + mock_config.project.version = "1.0.0" + result = handle_initialize(description) + + assert isinstance(result, types.InitializeResult) + assert result.protocolVersion == types.SERVER_LATEST_PROTOCOL_VERSION + assert result.instructions == "Test server" + + def test_handle_list_tools(self): + """Test list tools handler""" + app_name = "test_app" + app_mode = AppMode.CHAT.value + description = "Test server" + parameters_dict: dict[str, str] = {} + user_input_form: list[VariableEntity] = [] + + result = handle_list_tools(app_name, app_mode, user_input_form, description, parameters_dict) + + assert isinstance(result, types.ListToolsResult) + assert len(result.tools) == 1 + assert result.tools[0].name == "test_app" + assert result.tools[0].description == "Test server" + + @patch("core.mcp.server.streamable_http.AppGenerateService") + def test_handle_call_tool(self, mock_app_generate): + """Test call tool handler""" + app = Mock(spec=App) + app.mode = AppMode.CHAT.value + + # Create mock request + mock_request = Mock() + mock_call_request = Mock(spec=types.CallToolRequest) + mock_call_request.params = Mock() + mock_call_request.params.arguments = {"query": "test question"} + mock_request.root = mock_call_request + + user_input_form: list[VariableEntity] = [] + end_user = Mock(spec=EndUser) + + # Mock app generate service response + mock_response = {"answer": "test answer"} + mock_app_generate.generate.return_value = mock_response + + result = handle_call_tool(app, mock_request, user_input_form, end_user) + + assert isinstance(result, types.CallToolResult) + assert len(result.content) == 1 + # Type assertion needed due to union type + text_content = result.content[0] + assert hasattr(text_content, "text") + assert text_content.text == "test answer" # type: ignore[attr-defined] + + def test_handle_call_tool_no_end_user(self): + """Test call tool handler without end user""" + app = Mock(spec=App) + mock_request = Mock() + user_input_form: list[VariableEntity] = [] + + with pytest.raises(ValueError, match="End user not found"): + handle_call_tool(app, mock_request, user_input_form, None) + + +class TestUtilityFunctions: + """Test utility functions""" + + def test_build_parameter_schema_chat_mode(self): + """Test building parameter schema for chat mode""" + app_mode = AppMode.CHAT.value + parameters_dict: dict[str, str] = {"name": "Enter your name"} + + user_input_form = [ + VariableEntity( + type=VariableEntityType.TEXT_INPUT, + variable="name", + description="User name", + label="Name", + required=True, + ) + ] + + schema = build_parameter_schema(app_mode, user_input_form, parameters_dict) + + assert schema["type"] == "object" + assert "query" in schema["properties"] + assert "name" in schema["properties"] + assert "query" in schema["required"] + assert "name" in schema["required"] + + def test_build_parameter_schema_workflow_mode(self): + """Test building parameter schema for workflow mode""" + app_mode = AppMode.WORKFLOW.value + parameters_dict: dict[str, str] = {"input_text": "Enter text"} + + user_input_form = [ + VariableEntity( + type=VariableEntityType.TEXT_INPUT, + variable="input_text", + description="Input text", + label="Input", + required=True, + ) + ] + + schema = build_parameter_schema(app_mode, user_input_form, parameters_dict) + + assert schema["type"] == "object" + assert "query" not in schema["properties"] + assert "input_text" in schema["properties"] + assert "input_text" in schema["required"] + + def test_prepare_tool_arguments_chat_mode(self): + """Test preparing tool arguments for chat mode""" + app = Mock(spec=App) + app.mode = AppMode.CHAT.value + + arguments = {"query": "test question", "name": "John"} + + result = prepare_tool_arguments(app, arguments) + + assert result["query"] == "test question" + assert result["inputs"]["name"] == "John" + # Original arguments should not be modified + assert arguments["query"] == "test question" + + def test_prepare_tool_arguments_workflow_mode(self): + """Test preparing tool arguments for workflow mode""" + app = Mock(spec=App) + app.mode = AppMode.WORKFLOW.value + + arguments = {"input_text": "test input"} + + result = prepare_tool_arguments(app, arguments) + + assert "inputs" in result + assert result["inputs"]["input_text"] == "test input" + + def test_prepare_tool_arguments_completion_mode(self): + """Test preparing tool arguments for completion mode""" + app = Mock(spec=App) + app.mode = AppMode.COMPLETION.value + + arguments = {"name": "John"} + + result = prepare_tool_arguments(app, arguments) + + assert result["query"] == "" + assert result["inputs"]["name"] == "John" + + def test_extract_answer_from_mapping_response_chat(self): + """Test extracting answer from mapping response for chat mode""" + app = Mock(spec=App) + app.mode = AppMode.CHAT.value + + response = {"answer": "test answer", "other": "data"} + + result = extract_answer_from_response(app, response) + + assert result == "test answer" + + def test_extract_answer_from_mapping_response_workflow(self): + """Test extracting answer from mapping response for workflow mode""" + app = Mock(spec=App) + app.mode = AppMode.WORKFLOW.value + + response = {"data": {"outputs": {"result": "test result"}}} + + result = extract_answer_from_response(app, response) + + expected = json.dumps({"result": "test result"}, ensure_ascii=False) + assert result == expected + + def test_extract_answer_from_streaming_response(self): + """Test extracting answer from streaming response""" + app = Mock(spec=App) + + # Mock RateLimitGenerator + mock_generator = Mock(spec=RateLimitGenerator) + mock_generator.generator = [ + 'data: {"event": "agent_thought", "thought": "thinking..."}', + 'data: {"event": "agent_thought", "thought": "more thinking"}', + 'data: {"event": "other", "content": "ignore this"}', + "not data format", + ] + + result = extract_answer_from_response(app, mock_generator) + + assert result == "thinking...more thinking" + + def test_process_mapping_response_invalid_mode(self): + """Test processing mapping response with invalid app mode""" + app = Mock(spec=App) + app.mode = "invalid_mode" + + response = {"answer": "test"} + + with pytest.raises(ValueError, match="Invalid app mode"): + process_mapping_response(app, response) + + def test_convert_input_form_to_parameters(self): + """Test converting input form to parameters""" + user_input_form = [ + VariableEntity( + type=VariableEntityType.TEXT_INPUT, + variable="name", + description="User name", + label="Name", + required=True, + ), + VariableEntity( + type=VariableEntityType.SELECT, + variable="category", + description="Category", + label="Category", + required=False, + options=["A", "B", "C"], + ), + VariableEntity( + type=VariableEntityType.NUMBER, + variable="count", + description="Count", + label="Count", + required=True, + ), + VariableEntity( + type=VariableEntityType.FILE, + variable="upload", + description="File upload", + label="Upload", + required=False, + ), + ] + + parameters_dict: dict[str, str] = { + "name": "Enter your name", + "category": "Select category", + "count": "Enter count", + } + + parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict) + + # Check parameters + assert "name" in parameters + assert parameters["name"]["type"] == "string" + assert parameters["name"]["description"] == "Enter your name" + + assert "category" in parameters + assert parameters["category"]["type"] == "string" + assert parameters["category"]["enum"] == ["A", "B", "C"] + + assert "count" in parameters + assert parameters["count"]["type"] == "float" + + # FILE type should be skipped - it creates empty dict but gets filtered later + # Check that it doesn't have any meaningful content + if "upload" in parameters: + assert parameters["upload"] == {} + + # Check required fields + assert "name" in required + assert "count" in required + assert "category" not in required + + # Note: _get_request_id function has been removed as request_id is now passed as parameter