fix: change the mcp server strucutre to support github copilot (#24788)
This commit is contained in:
@@ -1,18 +1,27 @@
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from flask import Response
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from controllers.console.app.mcp_server import AppMCPServerStatus
|
from controllers.console.app.mcp_server import AppMCPServerStatus
|
||||||
from controllers.mcp import mcp_ns
|
from controllers.mcp import mcp_ns
|
||||||
from core.app.app_config.entities import VariableEntity
|
from core.app.app_config.entities import VariableEntity
|
||||||
from core.mcp import types
|
from core.mcp import types as mcp_types
|
||||||
from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler
|
from core.mcp.server.streamable_http import handle_mcp_request
|
||||||
from core.mcp.types import ClientNotification, ClientRequest
|
|
||||||
from core.mcp.utils import create_mcp_error_response
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from models.model import App, AppMCPServer, AppMode
|
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||||
|
|
||||||
|
|
||||||
|
class MCPRequestError(Exception):
|
||||||
|
"""Custom exception for MCP request processing errors"""
|
||||||
|
|
||||||
|
def __init__(self, error_code: int, message: str):
|
||||||
|
self.error_code = error_code
|
||||||
|
self.message = message
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
def int_or_str(value):
|
def int_or_str(value):
|
||||||
@@ -63,77 +72,173 @@ class MCPAppApi(Resource):
|
|||||||
Raises:
|
Raises:
|
||||||
ValidationError: Invalid request format or parameters
|
ValidationError: Invalid request format or parameters
|
||||||
"""
|
"""
|
||||||
# Parse and validate all arguments
|
|
||||||
args = mcp_request_parser.parse_args()
|
args = mcp_request_parser.parse_args()
|
||||||
|
|
||||||
request_id: Optional[Union[int, str]] = args.get("id")
|
request_id: Optional[Union[int, str]] = args.get("id")
|
||||||
|
mcp_request = self._parse_mcp_request(args)
|
||||||
|
|
||||||
server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
if not server:
|
# Get MCP server and app
|
||||||
return helper.compact_generate_response(
|
mcp_server, app = self._get_mcp_server_and_app(server_code, session)
|
||||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
|
self._validate_server_status(mcp_server)
|
||||||
)
|
|
||||||
|
|
||||||
if server.status != AppMCPServerStatus.ACTIVE:
|
# Get user input form
|
||||||
return helper.compact_generate_response(
|
user_input_form = self._get_user_input_form(app)
|
||||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
|
|
||||||
)
|
|
||||||
|
|
||||||
app = db.session.query(App).where(App.id == server.app_id).first()
|
# Handle notification vs request differently
|
||||||
|
return self._process_mcp_message(mcp_request, request_id, app, mcp_server, user_input_form, session)
|
||||||
|
|
||||||
|
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
|
||||||
|
"""Get and validate MCP server and app in one query session"""
|
||||||
|
mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
|
||||||
|
if not mcp_server:
|
||||||
|
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
|
||||||
|
|
||||||
|
app = session.query(App).where(App.id == mcp_server.app_id).first()
|
||||||
if not app:
|
if not app:
|
||||||
return helper.compact_generate_response(
|
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
|
||||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")
|
|
||||||
)
|
|
||||||
|
|
||||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
return mcp_server, app
|
||||||
workflow = app.workflow
|
|
||||||
if workflow is None:
|
|
||||||
return helper.compact_generate_response(
|
|
||||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
|
||||||
)
|
|
||||||
|
|
||||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
def _validate_server_status(self, mcp_server: AppMCPServer) -> None:
|
||||||
|
"""Validate MCP server status"""
|
||||||
|
if mcp_server.status != AppMCPServerStatus.ACTIVE:
|
||||||
|
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")
|
||||||
|
|
||||||
|
def _process_mcp_message(
|
||||||
|
self,
|
||||||
|
mcp_request: mcp_types.ClientRequest | mcp_types.ClientNotification,
|
||||||
|
request_id: Optional[Union[int, str]],
|
||||||
|
app: App,
|
||||||
|
mcp_server: AppMCPServer,
|
||||||
|
user_input_form: list[VariableEntity],
|
||||||
|
session: Session,
|
||||||
|
) -> Response:
|
||||||
|
"""Process MCP message (notification or request)"""
|
||||||
|
if isinstance(mcp_request, mcp_types.ClientNotification):
|
||||||
|
return self._handle_notification(mcp_request)
|
||||||
else:
|
else:
|
||||||
app_model_config = app.app_model_config
|
return self._handle_request(mcp_request, request_id, app, mcp_server, user_input_form, session)
|
||||||
if app_model_config is None:
|
|
||||||
return helper.compact_generate_response(
|
|
||||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
|
||||||
)
|
|
||||||
|
|
||||||
features_dict = app_model_config.to_dict()
|
def _handle_notification(self, mcp_request: mcp_types.ClientNotification) -> Response:
|
||||||
user_input_form = features_dict.get("user_input_form", [])
|
"""Handle MCP notification"""
|
||||||
converted_user_input_form: list[VariableEntity] = []
|
# For notifications, only support init notification
|
||||||
try:
|
if mcp_request.root.method != "notifications/initialized":
|
||||||
for item in user_input_form:
|
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Invalid notification method")
|
||||||
variable_type = item.get("type", "") or list(item.keys())[0]
|
# Return HTTP 202 Accepted for notifications (no response body)
|
||||||
variable = item[variable_type]
|
return Response("", status=202, content_type="application/json")
|
||||||
converted_user_input_form.append(
|
|
||||||
VariableEntity(
|
|
||||||
type=variable_type,
|
|
||||||
variable=variable.get("variable"),
|
|
||||||
description=variable.get("description") or "",
|
|
||||||
label=variable.get("label"),
|
|
||||||
required=variable.get("required", False),
|
|
||||||
max_length=variable.get("max_length"),
|
|
||||||
options=variable.get("options") or [],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except ValidationError as e:
|
|
||||||
return helper.compact_generate_response(
|
|
||||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
def _handle_request(
|
||||||
|
self,
|
||||||
|
mcp_request: mcp_types.ClientRequest,
|
||||||
|
request_id: Optional[Union[int, str]],
|
||||||
|
app: App,
|
||||||
|
mcp_server: AppMCPServer,
|
||||||
|
user_input_form: list[VariableEntity],
|
||||||
|
session: Session,
|
||||||
|
) -> Response:
|
||||||
|
"""Handle MCP request"""
|
||||||
|
if request_id is None:
|
||||||
|
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Request ID is required")
|
||||||
|
|
||||||
|
result = self._handle_mcp_request(app, mcp_server, mcp_request, user_input_form, session, request_id)
|
||||||
|
if result is None:
|
||||||
|
# This shouldn't happen for requests, but handle gracefully
|
||||||
|
raise MCPRequestError(mcp_types.INTERNAL_ERROR, "No response generated for request")
|
||||||
|
|
||||||
|
return helper.compact_generate_response(result.model_dump(by_alias=True, mode="json", exclude_none=True))
|
||||||
|
|
||||||
|
def _get_user_input_form(self, app: App) -> list[VariableEntity]:
|
||||||
|
"""Get and convert user input form"""
|
||||||
|
# Get raw user input form based on app mode
|
||||||
|
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||||
|
if not app.workflow:
|
||||||
|
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
|
||||||
|
raw_user_input_form = app.workflow.user_input_form(to_old_structure=True)
|
||||||
|
else:
|
||||||
|
if not app.app_model_config:
|
||||||
|
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
|
||||||
|
features_dict = app.app_model_config.to_dict()
|
||||||
|
raw_user_input_form = features_dict.get("user_input_form", [])
|
||||||
|
|
||||||
|
# Convert to VariableEntity objects
|
||||||
try:
|
try:
|
||||||
request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
|
return self._convert_user_input_form(raw_user_input_form)
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
|
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
||||||
|
|
||||||
|
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
|
||||||
|
"""Convert raw user input form to VariableEntity objects"""
|
||||||
|
return [self._create_variable_entity(item) for item in raw_form]
|
||||||
|
|
||||||
|
def _create_variable_entity(self, item: dict) -> VariableEntity:
|
||||||
|
"""Create a single VariableEntity from raw form item"""
|
||||||
|
variable_type = item.get("type", "") or list(item.keys())[0]
|
||||||
|
variable = item[variable_type]
|
||||||
|
|
||||||
|
return VariableEntity(
|
||||||
|
type=variable_type,
|
||||||
|
variable=variable.get("variable"),
|
||||||
|
description=variable.get("description") or "",
|
||||||
|
label=variable.get("label"),
|
||||||
|
required=variable.get("required", False),
|
||||||
|
max_length=variable.get("max_length"),
|
||||||
|
options=variable.get("options") or [],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
||||||
|
"""Parse and validate MCP request"""
|
||||||
|
try:
|
||||||
|
return mcp_types.ClientRequest.model_validate(args)
|
||||||
|
except ValidationError:
|
||||||
try:
|
try:
|
||||||
notification = ClientNotification.model_validate(args)
|
return mcp_types.ClientNotification.model_validate(args)
|
||||||
request = notification
|
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
return helper.compact_generate_response(
|
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
|
||||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
|
|
||||||
)
|
|
||||||
|
|
||||||
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
|
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None:
|
||||||
response = mcp_server_handler.handle()
|
"""Get end user from existing session - optimized query"""
|
||||||
return helper.compact_generate_response(response)
|
return (
|
||||||
|
session.query(EndUser)
|
||||||
|
.where(EndUser.tenant_id == tenant_id)
|
||||||
|
.where(EndUser.session_id == mcp_server_id)
|
||||||
|
.where(EndUser.type == "mcp")
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_end_user(
|
||||||
|
self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session
|
||||||
|
) -> EndUser:
|
||||||
|
"""Create end user in existing session"""
|
||||||
|
end_user = EndUser(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
app_id=app_id,
|
||||||
|
type="mcp",
|
||||||
|
name=client_name,
|
||||||
|
session_id=mcp_server_id,
|
||||||
|
)
|
||||||
|
session.add(end_user)
|
||||||
|
session.flush() # Use flush instead of commit to keep transaction open
|
||||||
|
session.refresh(end_user)
|
||||||
|
return end_user
|
||||||
|
|
||||||
|
def _handle_mcp_request(
|
||||||
|
self,
|
||||||
|
app: App,
|
||||||
|
mcp_server: AppMCPServer,
|
||||||
|
mcp_request: mcp_types.ClientRequest,
|
||||||
|
user_input_form: list[VariableEntity],
|
||||||
|
session: Session,
|
||||||
|
request_id: Union[int, str],
|
||||||
|
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
|
||||||
|
"""Handle MCP request and return response"""
|
||||||
|
end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session)
|
||||||
|
|
||||||
|
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
|
||||||
|
client_info = mcp_request.root.params.clientInfo
|
||||||
|
client_name = f"{client_info.name}@{client_info.version}"
|
||||||
|
# Commit the session before creating end user to avoid transaction conflicts
|
||||||
|
session.commit()
|
||||||
|
with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
|
||||||
|
end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)
|
||||||
|
|
||||||
|
return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)
|
||||||
|
@@ -4,224 +4,259 @@ from collections.abc import Mapping
|
|||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from controllers.web.passport import generate_session_id
|
|
||||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||||
from core.mcp import types
|
from core.mcp import types as mcp_types
|
||||||
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
|
|
||||||
from core.mcp.utils import create_mcp_error_response
|
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MCPServerStreamableHTTPRequestHandler:
|
def handle_mcp_request(
|
||||||
|
app: App,
|
||||||
|
request: mcp_types.ClientRequest,
|
||||||
|
user_input_form: list[VariableEntity],
|
||||||
|
mcp_server: AppMCPServer,
|
||||||
|
end_user: EndUser | None = None,
|
||||||
|
request_id: int | str = 1,
|
||||||
|
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError:
|
||||||
"""
|
"""
|
||||||
Apply to MCP HTTP streamable server with stateless http
|
Handle MCP request and return JSON-RPC response
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: The Dify app instance
|
||||||
|
request: The JSON-RPC request message
|
||||||
|
user_input_form: List of variable entities for the app
|
||||||
|
mcp_server: The MCP server configuration
|
||||||
|
end_user: Optional end user
|
||||||
|
request_id: The request ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON-RPC response or error
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
request_type = type(request.root)
|
||||||
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
|
|
||||||
):
|
|
||||||
self.app = app
|
|
||||||
self.request = request
|
|
||||||
mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first()
|
|
||||||
if not mcp_server:
|
|
||||||
raise ValueError("MCP server not found")
|
|
||||||
self.mcp_server: AppMCPServer = mcp_server
|
|
||||||
self.end_user = self.retrieve_end_user()
|
|
||||||
self.user_input_form = user_input_form
|
|
||||||
|
|
||||||
@property
|
def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse:
|
||||||
def request_type(self):
|
"""Create success response with business result data"""
|
||||||
return type(self.request.root)
|
return mcp_types.JSONRPCResponse(
|
||||||
|
jsonrpc="2.0",
|
||||||
|
id=request_id,
|
||||||
|
result=result_data.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
def create_error_response(code: int, message: str) -> mcp_types.JSONRPCError:
|
||||||
def parameter_schema(self):
|
"""Create error response with error code and message"""
|
||||||
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
|
from core.mcp.types import ErrorData
|
||||||
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
|
|
||||||
return {
|
error_data = ErrorData(code=code, message=message)
|
||||||
"type": "object",
|
return mcp_types.JSONRPCError(
|
||||||
"properties": parameters,
|
jsonrpc="2.0",
|
||||||
"required": required,
|
id=request_id,
|
||||||
}
|
error=error_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Request handler mapping using functional approach
|
||||||
|
request_handlers = {
|
||||||
|
mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description),
|
||||||
|
mcp_types.ListToolsRequest: lambda: handle_list_tools(
|
||||||
|
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
|
||||||
|
),
|
||||||
|
mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user),
|
||||||
|
mcp_types.PingRequest: lambda: handle_ping(),
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Dispatch request to appropriate handler
|
||||||
|
handler = request_handlers.get(request_type)
|
||||||
|
if handler:
|
||||||
|
return create_success_response(handler())
|
||||||
|
else:
|
||||||
|
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.exception("Invalid params")
|
||||||
|
return create_error_response(mcp_types.INVALID_PARAMS, str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Internal server error")
|
||||||
|
return create_error_response(mcp_types.INTERNAL_ERROR, "Internal server error: " + str(e))
|
||||||
|
|
||||||
|
|
||||||
|
def handle_ping() -> mcp_types.EmptyResult:
|
||||||
|
"""Handle ping request"""
|
||||||
|
return mcp_types.EmptyResult()
|
||||||
|
|
||||||
|
|
||||||
|
def handle_initialize(description: str) -> mcp_types.InitializeResult:
|
||||||
|
"""Handle initialize request"""
|
||||||
|
capabilities = mcp_types.ServerCapabilities(
|
||||||
|
tools=mcp_types.ToolsCapability(listChanged=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
return mcp_types.InitializeResult(
|
||||||
|
protocolVersion=mcp_types.SERVER_LATEST_PROTOCOL_VERSION,
|
||||||
|
capabilities=capabilities,
|
||||||
|
serverInfo=mcp_types.Implementation(name="Dify", version=dify_config.project.version),
|
||||||
|
instructions=description,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_list_tools(
|
||||||
|
app_name: str,
|
||||||
|
app_mode: str,
|
||||||
|
user_input_form: list[VariableEntity],
|
||||||
|
description: str,
|
||||||
|
parameters_dict: dict[str, str],
|
||||||
|
) -> mcp_types.ListToolsResult:
|
||||||
|
"""Handle list tools request"""
|
||||||
|
parameter_schema = build_parameter_schema(app_mode, user_input_form, parameters_dict)
|
||||||
|
|
||||||
|
return mcp_types.ListToolsResult(
|
||||||
|
tools=[
|
||||||
|
mcp_types.Tool(
|
||||||
|
name=app_name,
|
||||||
|
description=description,
|
||||||
|
inputSchema=parameter_schema,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_call_tool(
|
||||||
|
app: App,
|
||||||
|
request: mcp_types.ClientRequest,
|
||||||
|
user_input_form: list[VariableEntity],
|
||||||
|
end_user: EndUser | None,
|
||||||
|
) -> mcp_types.CallToolResult:
|
||||||
|
"""Handle call tool request"""
|
||||||
|
request_obj = cast(mcp_types.CallToolRequest, request.root)
|
||||||
|
args = prepare_tool_arguments(app, request_obj.params.arguments or {})
|
||||||
|
|
||||||
|
if not end_user:
|
||||||
|
raise ValueError("End user not found")
|
||||||
|
|
||||||
|
response = AppGenerateService.generate(
|
||||||
|
app,
|
||||||
|
end_user,
|
||||||
|
args,
|
||||||
|
InvokeFrom.SERVICE_API,
|
||||||
|
streaming=app.mode == AppMode.AGENT_CHAT.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
answer = extract_answer_from_response(app, response)
|
||||||
|
return mcp_types.CallToolResult(content=[mcp_types.TextContent(text=answer, type="text")])
|
||||||
|
|
||||||
|
|
||||||
|
def build_parameter_schema(
|
||||||
|
app_mode: str,
|
||||||
|
user_input_form: list[VariableEntity],
|
||||||
|
parameters_dict: dict[str, str],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Build parameter schema for the tool"""
|
||||||
|
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
|
||||||
|
|
||||||
|
if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
|
||||||
return {
|
return {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": parameters,
|
||||||
"query": {"type": "string", "description": "User Input/Question content"},
|
"required": required,
|
||||||
**parameters,
|
|
||||||
},
|
|
||||||
"required": ["query", *required],
|
|
||||||
}
|
}
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string", "description": "User Input/Question content"},
|
||||||
|
**parameters,
|
||||||
|
},
|
||||||
|
"required": ["query", *required],
|
||||||
|
}
|
||||||
|
|
||||||
@property
|
|
||||||
def capabilities(self):
|
|
||||||
return types.ServerCapabilities(
|
|
||||||
tools=types.ToolsCapability(listChanged=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
def response(self, response: types.Result | str):
|
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||||
if isinstance(response, str):
|
"""Prepare arguments based on app mode"""
|
||||||
sse_content = f"event: ping\ndata: {response}\n\n".encode()
|
if app.mode == AppMode.WORKFLOW.value:
|
||||||
yield sse_content
|
return {"inputs": arguments}
|
||||||
return
|
elif app.mode == AppMode.COMPLETION.value:
|
||||||
json_response = types.JSONRPCResponse(
|
return {"query": "", "inputs": arguments}
|
||||||
jsonrpc="2.0",
|
else:
|
||||||
id=(self.request.root.model_extra or {}).get("id", 1),
|
# Chat modes - create a copy to avoid modifying original dict
|
||||||
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
|
args_copy = arguments.copy()
|
||||||
)
|
query = args_copy.pop("query", "")
|
||||||
json_data = json.dumps(jsonable_encoder(json_response))
|
return {"query": query, "inputs": args_copy}
|
||||||
|
|
||||||
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
|
|
||||||
|
|
||||||
yield sse_content
|
def extract_answer_from_response(app: App, response: Any) -> str:
|
||||||
|
"""Extract answer from app generate response"""
|
||||||
|
answer = ""
|
||||||
|
|
||||||
def error_response(self, code: int, message: str, data=None):
|
if isinstance(response, RateLimitGenerator):
|
||||||
request_id = (self.request.root.model_extra or {}).get("id", 1) or 1
|
answer = process_streaming_response(response)
|
||||||
return create_mcp_error_response(request_id, code, message, data)
|
elif isinstance(response, Mapping):
|
||||||
|
answer = process_mapping_response(app, response)
|
||||||
|
else:
|
||||||
|
logger.warning("Unexpected response type: %s", type(response))
|
||||||
|
|
||||||
def handle(self):
|
return answer
|
||||||
handle_map = {
|
|
||||||
types.InitializeRequest: self.initialize,
|
|
||||||
types.ListToolsRequest: self.list_tools,
|
|
||||||
types.CallToolRequest: self.invoke_tool,
|
|
||||||
types.InitializedNotification: self.handle_notification,
|
|
||||||
types.PingRequest: self.handle_ping,
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
if self.request_type in handle_map:
|
|
||||||
return self.response(handle_map[self.request_type]())
|
|
||||||
else:
|
|
||||||
return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
|
|
||||||
except ValueError as e:
|
|
||||||
logger.exception("Invalid params")
|
|
||||||
return self.error_response(INVALID_PARAMS, str(e))
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Internal server error")
|
|
||||||
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
|
|
||||||
|
|
||||||
def handle_notification(self):
|
|
||||||
return "ping"
|
|
||||||
|
|
||||||
def handle_ping(self):
|
def process_streaming_response(response: RateLimitGenerator) -> str:
|
||||||
return types.EmptyResult()
|
"""Process streaming response for agent chat mode"""
|
||||||
|
answer = ""
|
||||||
def initialize(self):
|
for item in response.generator:
|
||||||
request = cast(types.InitializeRequest, self.request.root)
|
if isinstance(item, str) and item.startswith("data: "):
|
||||||
client_info = request.params.clientInfo
|
|
||||||
client_name = f"{client_info.name}@{client_info.version}"
|
|
||||||
if not self.end_user:
|
|
||||||
end_user = EndUser(
|
|
||||||
tenant_id=self.app.tenant_id,
|
|
||||||
app_id=self.app.id,
|
|
||||||
type="mcp",
|
|
||||||
name=client_name,
|
|
||||||
session_id=generate_session_id(),
|
|
||||||
external_user_id=self.mcp_server.id,
|
|
||||||
)
|
|
||||||
db.session.add(end_user)
|
|
||||||
db.session.commit()
|
|
||||||
return types.InitializeResult(
|
|
||||||
protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION,
|
|
||||||
capabilities=self.capabilities,
|
|
||||||
serverInfo=types.Implementation(name="Dify", version=dify_config.project.version),
|
|
||||||
instructions=self.mcp_server.description,
|
|
||||||
)
|
|
||||||
|
|
||||||
def list_tools(self):
|
|
||||||
if not self.end_user:
|
|
||||||
raise ValueError("User not found")
|
|
||||||
return types.ListToolsResult(
|
|
||||||
tools=[
|
|
||||||
types.Tool(
|
|
||||||
name=self.app.name,
|
|
||||||
description=self.mcp_server.description,
|
|
||||||
inputSchema=self.parameter_schema,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
def invoke_tool(self):
|
|
||||||
if not self.end_user:
|
|
||||||
raise ValueError("User not found")
|
|
||||||
request = cast(types.CallToolRequest, self.request.root)
|
|
||||||
args = request.params.arguments or {}
|
|
||||||
if self.app.mode in {AppMode.WORKFLOW.value}:
|
|
||||||
args = {"inputs": args}
|
|
||||||
elif self.app.mode in {AppMode.COMPLETION.value}:
|
|
||||||
args = {"query": "", "inputs": args}
|
|
||||||
else:
|
|
||||||
args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
|
|
||||||
response = AppGenerateService.generate(
|
|
||||||
self.app,
|
|
||||||
self.end_user,
|
|
||||||
args,
|
|
||||||
InvokeFrom.SERVICE_API,
|
|
||||||
streaming=self.app.mode == AppMode.AGENT_CHAT.value,
|
|
||||||
)
|
|
||||||
answer = ""
|
|
||||||
if isinstance(response, RateLimitGenerator):
|
|
||||||
for item in response.generator:
|
|
||||||
data = item
|
|
||||||
if isinstance(data, str) and data.startswith("data: "):
|
|
||||||
try:
|
|
||||||
json_str = data[6:].strip()
|
|
||||||
parsed_data = json.loads(json_str)
|
|
||||||
if parsed_data.get("event") == "agent_thought":
|
|
||||||
answer += parsed_data.get("thought", "")
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
if isinstance(response, Mapping):
|
|
||||||
if self.app.mode in {
|
|
||||||
AppMode.ADVANCED_CHAT.value,
|
|
||||||
AppMode.COMPLETION.value,
|
|
||||||
AppMode.CHAT.value,
|
|
||||||
AppMode.AGENT_CHAT.value,
|
|
||||||
}:
|
|
||||||
answer = response["answer"]
|
|
||||||
elif self.app.mode in {AppMode.WORKFLOW.value}:
|
|
||||||
answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid app mode")
|
|
||||||
# Not support image yet
|
|
||||||
return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
|
|
||||||
|
|
||||||
def retrieve_end_user(self):
|
|
||||||
return (
|
|
||||||
db.session.query(EndUser)
|
|
||||||
.where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
|
|
||||||
parameters: dict[str, dict[str, Any]] = {}
|
|
||||||
required = []
|
|
||||||
for item in user_input_form:
|
|
||||||
parameters[item.variable] = {}
|
|
||||||
if item.type in (
|
|
||||||
VariableEntityType.FILE,
|
|
||||||
VariableEntityType.FILE_LIST,
|
|
||||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
if item.required:
|
|
||||||
required.append(item.variable)
|
|
||||||
# if the workflow republished, the parameters not changed
|
|
||||||
# we should not raise error here
|
|
||||||
try:
|
try:
|
||||||
description = self.mcp_server.parameters_dict[item.variable]
|
json_str = item[6:].strip()
|
||||||
except KeyError:
|
parsed_data = json.loads(json_str)
|
||||||
description = ""
|
if parsed_data.get("event") == "agent_thought":
|
||||||
parameters[item.variable]["description"] = description
|
answer += parsed_data.get("thought", "")
|
||||||
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
except json.JSONDecodeError:
|
||||||
parameters[item.variable]["type"] = "string"
|
continue
|
||||||
elif item.type == VariableEntityType.SELECT:
|
return answer
|
||||||
parameters[item.variable]["type"] = "string"
|
|
||||||
parameters[item.variable]["enum"] = item.options
|
|
||||||
elif item.type == VariableEntityType.NUMBER:
|
def process_mapping_response(app: App, response: Mapping) -> str:
|
||||||
parameters[item.variable]["type"] = "float"
|
"""Process mapping response based on app mode"""
|
||||||
return parameters, required
|
if app.mode in {
|
||||||
|
AppMode.ADVANCED_CHAT.value,
|
||||||
|
AppMode.COMPLETION.value,
|
||||||
|
AppMode.CHAT.value,
|
||||||
|
AppMode.AGENT_CHAT.value,
|
||||||
|
}:
|
||||||
|
return response.get("answer", "")
|
||||||
|
elif app.mode == AppMode.WORKFLOW.value:
|
||||||
|
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid app mode: " + str(app.mode))
|
||||||
|
|
||||||
|
|
||||||
|
def convert_input_form_to_parameters(
|
||||||
|
user_input_form: list[VariableEntity],
|
||||||
|
parameters_dict: dict[str, str],
|
||||||
|
) -> tuple[dict[str, dict[str, Any]], list[str]]:
|
||||||
|
"""Convert user input form to parameter schema"""
|
||||||
|
parameters: dict[str, dict[str, Any]] = {}
|
||||||
|
required = []
|
||||||
|
|
||||||
|
for item in user_input_form:
|
||||||
|
if item.type in (
|
||||||
|
VariableEntityType.FILE,
|
||||||
|
VariableEntityType.FILE_LIST,
|
||||||
|
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
parameters[item.variable] = {}
|
||||||
|
if item.required:
|
||||||
|
required.append(item.variable)
|
||||||
|
# if the workflow republished, the parameters not changed
|
||||||
|
# we should not raise error here
|
||||||
|
description = parameters_dict.get(item.variable, "")
|
||||||
|
parameters[item.variable]["description"] = description
|
||||||
|
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
||||||
|
parameters[item.variable]["type"] = "string"
|
||||||
|
elif item.type == VariableEntityType.SELECT:
|
||||||
|
parameters[item.variable]["type"] = "string"
|
||||||
|
parameters[item.variable]["enum"] = item.options
|
||||||
|
elif item.type == VariableEntityType.NUMBER:
|
||||||
|
parameters[item.variable]["type"] = "float"
|
||||||
|
return parameters, required
|
||||||
|
@@ -138,5 +138,5 @@ def create_mcp_error_response(
|
|||||||
error=error_data,
|
error=error_data,
|
||||||
)
|
)
|
||||||
json_data = json.dumps(jsonable_encoder(json_response))
|
json_data = json.dumps(jsonable_encoder(json_response))
|
||||||
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
|
sse_content = json_data.encode()
|
||||||
yield sse_content
|
yield sse_content
|
||||||
|
1
api/tests/unit_tests/core/mcp/server/__init__.py
Normal file
1
api/tests/unit_tests/core/mcp/server/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# MCP server tests
|
449
api/tests/unit_tests/core/mcp/server/test_streamable_http.py
Normal file
449
api/tests/unit_tests/core/mcp/server/test_streamable_http.py
Normal file
@@ -0,0 +1,449 @@
|
|||||||
|
import json
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||||
|
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||||
|
from core.mcp import types
|
||||||
|
from core.mcp.server.streamable_http import (
|
||||||
|
build_parameter_schema,
|
||||||
|
convert_input_form_to_parameters,
|
||||||
|
extract_answer_from_response,
|
||||||
|
handle_call_tool,
|
||||||
|
handle_initialize,
|
||||||
|
handle_list_tools,
|
||||||
|
handle_mcp_request,
|
||||||
|
handle_ping,
|
||||||
|
prepare_tool_arguments,
|
||||||
|
process_mapping_response,
|
||||||
|
)
|
||||||
|
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||||
|
|
||||||
|
|
||||||
|
class TestHandleMCPRequest:
|
||||||
|
"""Test handle_mcp_request function"""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Setup test fixtures"""
|
||||||
|
self.app = Mock(spec=App)
|
||||||
|
self.app.name = "test_app"
|
||||||
|
self.app.mode = AppMode.CHAT.value
|
||||||
|
|
||||||
|
self.mcp_server = Mock(spec=AppMCPServer)
|
||||||
|
self.mcp_server.description = "Test server"
|
||||||
|
self.mcp_server.parameters_dict = {}
|
||||||
|
|
||||||
|
self.end_user = Mock(spec=EndUser)
|
||||||
|
self.user_input_form = []
|
||||||
|
|
||||||
|
# Create mock request
|
||||||
|
self.mock_request = Mock()
|
||||||
|
self.mock_request.root = Mock()
|
||||||
|
self.mock_request.root.id = 123
|
||||||
|
|
||||||
|
def test_handle_ping_request(self):
|
||||||
|
"""Test handling ping request"""
|
||||||
|
# Setup ping request
|
||||||
|
self.mock_request.root = Mock(spec=types.PingRequest)
|
||||||
|
self.mock_request.root.id = 123
|
||||||
|
request_type = Mock(return_value=types.PingRequest)
|
||||||
|
|
||||||
|
with patch("core.mcp.server.streamable_http.type", request_type):
|
||||||
|
result = handle_mcp_request(
|
||||||
|
self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, types.JSONRPCResponse)
|
||||||
|
assert result.jsonrpc == "2.0"
|
||||||
|
assert result.id == 123
|
||||||
|
|
||||||
|
def test_handle_initialize_request(self):
|
||||||
|
"""Test handling initialize request"""
|
||||||
|
# Setup initialize request
|
||||||
|
self.mock_request.root = Mock(spec=types.InitializeRequest)
|
||||||
|
self.mock_request.root.id = 123
|
||||||
|
request_type = Mock(return_value=types.InitializeRequest)
|
||||||
|
|
||||||
|
with patch("core.mcp.server.streamable_http.type", request_type):
|
||||||
|
result = handle_mcp_request(
|
||||||
|
self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, types.JSONRPCResponse)
|
||||||
|
assert result.jsonrpc == "2.0"
|
||||||
|
assert result.id == 123
|
||||||
|
|
||||||
|
def test_handle_list_tools_request(self):
|
||||||
|
"""Test handling list tools request"""
|
||||||
|
# Setup list tools request
|
||||||
|
self.mock_request.root = Mock(spec=types.ListToolsRequest)
|
||||||
|
self.mock_request.root.id = 123
|
||||||
|
request_type = Mock(return_value=types.ListToolsRequest)
|
||||||
|
|
||||||
|
with patch("core.mcp.server.streamable_http.type", request_type):
|
||||||
|
result = handle_mcp_request(
|
||||||
|
self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, types.JSONRPCResponse)
|
||||||
|
assert result.jsonrpc == "2.0"
|
||||||
|
assert result.id == 123
|
||||||
|
|
||||||
|
@patch("core.mcp.server.streamable_http.AppGenerateService")
|
||||||
|
def test_handle_call_tool_request(self, mock_app_generate):
|
||||||
|
"""Test handling call tool request"""
|
||||||
|
# Setup call tool request
|
||||||
|
mock_call_request = Mock(spec=types.CallToolRequest)
|
||||||
|
mock_call_request.params = Mock()
|
||||||
|
mock_call_request.params.arguments = {"query": "test question"}
|
||||||
|
mock_call_request.id = 123
|
||||||
|
|
||||||
|
self.mock_request.root = mock_call_request
|
||||||
|
request_type = Mock(return_value=types.CallToolRequest)
|
||||||
|
|
||||||
|
# Mock app generate service response
|
||||||
|
mock_response = {"answer": "test answer"}
|
||||||
|
mock_app_generate.generate.return_value = mock_response
|
||||||
|
|
||||||
|
with patch("core.mcp.server.streamable_http.type", request_type):
|
||||||
|
result = handle_mcp_request(
|
||||||
|
self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, types.JSONRPCResponse)
|
||||||
|
assert result.jsonrpc == "2.0"
|
||||||
|
assert result.id == 123
|
||||||
|
|
||||||
|
# Verify AppGenerateService was called
|
||||||
|
mock_app_generate.generate.assert_called_once()
|
||||||
|
|
||||||
|
def test_handle_unknown_request_type(self):
|
||||||
|
"""Test handling unknown request type"""
|
||||||
|
|
||||||
|
# Setup unknown request
|
||||||
|
class UnknownRequest:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.mock_request.root = Mock(spec=UnknownRequest)
|
||||||
|
self.mock_request.root.id = 123
|
||||||
|
request_type = Mock(return_value=UnknownRequest)
|
||||||
|
|
||||||
|
with patch("core.mcp.server.streamable_http.type", request_type):
|
||||||
|
result = handle_mcp_request(
|
||||||
|
self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, types.JSONRPCError)
|
||||||
|
assert result.jsonrpc == "2.0"
|
||||||
|
assert result.id == 123
|
||||||
|
assert result.error.code == types.METHOD_NOT_FOUND
|
||||||
|
|
||||||
|
def test_handle_value_error(self):
|
||||||
|
"""Test handling ValueError"""
|
||||||
|
# Setup request that will cause ValueError
|
||||||
|
self.mock_request.root = Mock(spec=types.CallToolRequest)
|
||||||
|
self.mock_request.root.params = Mock()
|
||||||
|
self.mock_request.root.params.arguments = {}
|
||||||
|
|
||||||
|
request_type = Mock(return_value=types.CallToolRequest)
|
||||||
|
|
||||||
|
# Don't provide end_user to cause ValueError
|
||||||
|
with patch("core.mcp.server.streamable_http.type", request_type):
|
||||||
|
result = handle_mcp_request(self.app, self.mock_request, self.user_input_form, self.mcp_server, None, 123)
|
||||||
|
|
||||||
|
assert isinstance(result, types.JSONRPCError)
|
||||||
|
assert result.error.code == types.INVALID_PARAMS
|
||||||
|
|
||||||
|
def test_handle_generic_exception(self):
|
||||||
|
"""Test handling generic exception"""
|
||||||
|
# Setup request that will cause generic exception
|
||||||
|
self.mock_request.root = Mock(spec=types.PingRequest)
|
||||||
|
self.mock_request.root.id = 123
|
||||||
|
|
||||||
|
# Patch handle_ping to raise exception instead of type
|
||||||
|
with patch("core.mcp.server.streamable_http.handle_ping", side_effect=Exception("Test error")):
|
||||||
|
with patch("core.mcp.server.streamable_http.type", return_value=types.PingRequest):
|
||||||
|
result = handle_mcp_request(
|
||||||
|
self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, types.JSONRPCError)
|
||||||
|
assert result.error.code == types.INTERNAL_ERROR
|
||||||
|
|
||||||
|
|
||||||
|
class TestIndividualHandlers:
|
||||||
|
"""Test individual handler functions"""
|
||||||
|
|
||||||
|
def test_handle_ping(self):
|
||||||
|
"""Test ping handler"""
|
||||||
|
result = handle_ping()
|
||||||
|
assert isinstance(result, types.EmptyResult)
|
||||||
|
|
||||||
|
def test_handle_initialize(self):
|
||||||
|
"""Test initialize handler"""
|
||||||
|
description = "Test server"
|
||||||
|
|
||||||
|
with patch("core.mcp.server.streamable_http.dify_config") as mock_config:
|
||||||
|
mock_config.project.version = "1.0.0"
|
||||||
|
result = handle_initialize(description)
|
||||||
|
|
||||||
|
assert isinstance(result, types.InitializeResult)
|
||||||
|
assert result.protocolVersion == types.SERVER_LATEST_PROTOCOL_VERSION
|
||||||
|
assert result.instructions == "Test server"
|
||||||
|
|
||||||
|
def test_handle_list_tools(self):
|
||||||
|
"""Test list tools handler"""
|
||||||
|
app_name = "test_app"
|
||||||
|
app_mode = AppMode.CHAT.value
|
||||||
|
description = "Test server"
|
||||||
|
parameters_dict: dict[str, str] = {}
|
||||||
|
user_input_form: list[VariableEntity] = []
|
||||||
|
|
||||||
|
result = handle_list_tools(app_name, app_mode, user_input_form, description, parameters_dict)
|
||||||
|
|
||||||
|
assert isinstance(result, types.ListToolsResult)
|
||||||
|
assert len(result.tools) == 1
|
||||||
|
assert result.tools[0].name == "test_app"
|
||||||
|
assert result.tools[0].description == "Test server"
|
||||||
|
|
||||||
|
@patch("core.mcp.server.streamable_http.AppGenerateService")
|
||||||
|
def test_handle_call_tool(self, mock_app_generate):
|
||||||
|
"""Test call tool handler"""
|
||||||
|
app = Mock(spec=App)
|
||||||
|
app.mode = AppMode.CHAT.value
|
||||||
|
|
||||||
|
# Create mock request
|
||||||
|
mock_request = Mock()
|
||||||
|
mock_call_request = Mock(spec=types.CallToolRequest)
|
||||||
|
mock_call_request.params = Mock()
|
||||||
|
mock_call_request.params.arguments = {"query": "test question"}
|
||||||
|
mock_request.root = mock_call_request
|
||||||
|
|
||||||
|
user_input_form: list[VariableEntity] = []
|
||||||
|
end_user = Mock(spec=EndUser)
|
||||||
|
|
||||||
|
# Mock app generate service response
|
||||||
|
mock_response = {"answer": "test answer"}
|
||||||
|
mock_app_generate.generate.return_value = mock_response
|
||||||
|
|
||||||
|
result = handle_call_tool(app, mock_request, user_input_form, end_user)
|
||||||
|
|
||||||
|
assert isinstance(result, types.CallToolResult)
|
||||||
|
assert len(result.content) == 1
|
||||||
|
# Type assertion needed due to union type
|
||||||
|
text_content = result.content[0]
|
||||||
|
assert hasattr(text_content, "text")
|
||||||
|
assert text_content.text == "test answer" # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
def test_handle_call_tool_no_end_user(self):
|
||||||
|
"""Test call tool handler without end user"""
|
||||||
|
app = Mock(spec=App)
|
||||||
|
mock_request = Mock()
|
||||||
|
user_input_form: list[VariableEntity] = []
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="End user not found"):
|
||||||
|
handle_call_tool(app, mock_request, user_input_form, None)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUtilityFunctions:
|
||||||
|
"""Test utility functions"""
|
||||||
|
|
||||||
|
def test_build_parameter_schema_chat_mode(self):
|
||||||
|
"""Test building parameter schema for chat mode"""
|
||||||
|
app_mode = AppMode.CHAT.value
|
||||||
|
parameters_dict: dict[str, str] = {"name": "Enter your name"}
|
||||||
|
|
||||||
|
user_input_form = [
|
||||||
|
VariableEntity(
|
||||||
|
type=VariableEntityType.TEXT_INPUT,
|
||||||
|
variable="name",
|
||||||
|
description="User name",
|
||||||
|
label="Name",
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
schema = build_parameter_schema(app_mode, user_input_form, parameters_dict)
|
||||||
|
|
||||||
|
assert schema["type"] == "object"
|
||||||
|
assert "query" in schema["properties"]
|
||||||
|
assert "name" in schema["properties"]
|
||||||
|
assert "query" in schema["required"]
|
||||||
|
assert "name" in schema["required"]
|
||||||
|
|
||||||
|
def test_build_parameter_schema_workflow_mode(self):
|
||||||
|
"""Test building parameter schema for workflow mode"""
|
||||||
|
app_mode = AppMode.WORKFLOW.value
|
||||||
|
parameters_dict: dict[str, str] = {"input_text": "Enter text"}
|
||||||
|
|
||||||
|
user_input_form = [
|
||||||
|
VariableEntity(
|
||||||
|
type=VariableEntityType.TEXT_INPUT,
|
||||||
|
variable="input_text",
|
||||||
|
description="Input text",
|
||||||
|
label="Input",
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
schema = build_parameter_schema(app_mode, user_input_form, parameters_dict)
|
||||||
|
|
||||||
|
assert schema["type"] == "object"
|
||||||
|
assert "query" not in schema["properties"]
|
||||||
|
assert "input_text" in schema["properties"]
|
||||||
|
assert "input_text" in schema["required"]
|
||||||
|
|
||||||
|
def test_prepare_tool_arguments_chat_mode(self):
|
||||||
|
"""Test preparing tool arguments for chat mode"""
|
||||||
|
app = Mock(spec=App)
|
||||||
|
app.mode = AppMode.CHAT.value
|
||||||
|
|
||||||
|
arguments = {"query": "test question", "name": "John"}
|
||||||
|
|
||||||
|
result = prepare_tool_arguments(app, arguments)
|
||||||
|
|
||||||
|
assert result["query"] == "test question"
|
||||||
|
assert result["inputs"]["name"] == "John"
|
||||||
|
# Original arguments should not be modified
|
||||||
|
assert arguments["query"] == "test question"
|
||||||
|
|
||||||
|
def test_prepare_tool_arguments_workflow_mode(self):
|
||||||
|
"""Test preparing tool arguments for workflow mode"""
|
||||||
|
app = Mock(spec=App)
|
||||||
|
app.mode = AppMode.WORKFLOW.value
|
||||||
|
|
||||||
|
arguments = {"input_text": "test input"}
|
||||||
|
|
||||||
|
result = prepare_tool_arguments(app, arguments)
|
||||||
|
|
||||||
|
assert "inputs" in result
|
||||||
|
assert result["inputs"]["input_text"] == "test input"
|
||||||
|
|
||||||
|
def test_prepare_tool_arguments_completion_mode(self):
|
||||||
|
"""Test preparing tool arguments for completion mode"""
|
||||||
|
app = Mock(spec=App)
|
||||||
|
app.mode = AppMode.COMPLETION.value
|
||||||
|
|
||||||
|
arguments = {"name": "John"}
|
||||||
|
|
||||||
|
result = prepare_tool_arguments(app, arguments)
|
||||||
|
|
||||||
|
assert result["query"] == ""
|
||||||
|
assert result["inputs"]["name"] == "John"
|
||||||
|
|
||||||
|
def test_extract_answer_from_mapping_response_chat(self):
|
||||||
|
"""Test extracting answer from mapping response for chat mode"""
|
||||||
|
app = Mock(spec=App)
|
||||||
|
app.mode = AppMode.CHAT.value
|
||||||
|
|
||||||
|
response = {"answer": "test answer", "other": "data"}
|
||||||
|
|
||||||
|
result = extract_answer_from_response(app, response)
|
||||||
|
|
||||||
|
assert result == "test answer"
|
||||||
|
|
||||||
|
def test_extract_answer_from_mapping_response_workflow(self):
|
||||||
|
"""Test extracting answer from mapping response for workflow mode"""
|
||||||
|
app = Mock(spec=App)
|
||||||
|
app.mode = AppMode.WORKFLOW.value
|
||||||
|
|
||||||
|
response = {"data": {"outputs": {"result": "test result"}}}
|
||||||
|
|
||||||
|
result = extract_answer_from_response(app, response)
|
||||||
|
|
||||||
|
expected = json.dumps({"result": "test result"}, ensure_ascii=False)
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
def test_extract_answer_from_streaming_response(self):
|
||||||
|
"""Test extracting answer from streaming response"""
|
||||||
|
app = Mock(spec=App)
|
||||||
|
|
||||||
|
# Mock RateLimitGenerator
|
||||||
|
mock_generator = Mock(spec=RateLimitGenerator)
|
||||||
|
mock_generator.generator = [
|
||||||
|
'data: {"event": "agent_thought", "thought": "thinking..."}',
|
||||||
|
'data: {"event": "agent_thought", "thought": "more thinking"}',
|
||||||
|
'data: {"event": "other", "content": "ignore this"}',
|
||||||
|
"not data format",
|
||||||
|
]
|
||||||
|
|
||||||
|
result = extract_answer_from_response(app, mock_generator)
|
||||||
|
|
||||||
|
assert result == "thinking...more thinking"
|
||||||
|
|
||||||
|
def test_process_mapping_response_invalid_mode(self):
|
||||||
|
"""Test processing mapping response with invalid app mode"""
|
||||||
|
app = Mock(spec=App)
|
||||||
|
app.mode = "invalid_mode"
|
||||||
|
|
||||||
|
response = {"answer": "test"}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Invalid app mode"):
|
||||||
|
process_mapping_response(app, response)
|
||||||
|
|
||||||
|
def test_convert_input_form_to_parameters(self):
|
||||||
|
"""Test converting input form to parameters"""
|
||||||
|
user_input_form = [
|
||||||
|
VariableEntity(
|
||||||
|
type=VariableEntityType.TEXT_INPUT,
|
||||||
|
variable="name",
|
||||||
|
description="User name",
|
||||||
|
label="Name",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
VariableEntity(
|
||||||
|
type=VariableEntityType.SELECT,
|
||||||
|
variable="category",
|
||||||
|
description="Category",
|
||||||
|
label="Category",
|
||||||
|
required=False,
|
||||||
|
options=["A", "B", "C"],
|
||||||
|
),
|
||||||
|
VariableEntity(
|
||||||
|
type=VariableEntityType.NUMBER,
|
||||||
|
variable="count",
|
||||||
|
description="Count",
|
||||||
|
label="Count",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
VariableEntity(
|
||||||
|
type=VariableEntityType.FILE,
|
||||||
|
variable="upload",
|
||||||
|
description="File upload",
|
||||||
|
label="Upload",
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
parameters_dict: dict[str, str] = {
|
||||||
|
"name": "Enter your name",
|
||||||
|
"category": "Select category",
|
||||||
|
"count": "Enter count",
|
||||||
|
}
|
||||||
|
|
||||||
|
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
|
||||||
|
|
||||||
|
# Check parameters
|
||||||
|
assert "name" in parameters
|
||||||
|
assert parameters["name"]["type"] == "string"
|
||||||
|
assert parameters["name"]["description"] == "Enter your name"
|
||||||
|
|
||||||
|
assert "category" in parameters
|
||||||
|
assert parameters["category"]["type"] == "string"
|
||||||
|
assert parameters["category"]["enum"] == ["A", "B", "C"]
|
||||||
|
|
||||||
|
assert "count" in parameters
|
||||||
|
assert parameters["count"]["type"] == "float"
|
||||||
|
|
||||||
|
# FILE type should be skipped - it creates empty dict but gets filtered later
|
||||||
|
# Check that it doesn't have any meaningful content
|
||||||
|
if "upload" in parameters:
|
||||||
|
assert parameters["upload"] == {}
|
||||||
|
|
||||||
|
# Check required fields
|
||||||
|
assert "name" in required
|
||||||
|
assert "count" in required
|
||||||
|
assert "category" not in required
|
||||||
|
|
||||||
|
# Note: _get_request_id function has been removed as request_id is now passed as parameter
|
Reference in New Issue
Block a user