fix: change the mcp server strucutre to support github copilot (#24788)
This commit is contained in:
@@ -4,224 +4,259 @@ from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web.passport import generate_session_id
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||
from core.mcp import types
|
||||
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
|
||||
from core.mcp.utils import create_mcp_error_response
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from core.mcp import types as mcp_types
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPServerStreamableHTTPRequestHandler:
|
||||
def handle_mcp_request(
|
||||
app: App,
|
||||
request: mcp_types.ClientRequest,
|
||||
user_input_form: list[VariableEntity],
|
||||
mcp_server: AppMCPServer,
|
||||
end_user: EndUser | None = None,
|
||||
request_id: int | str = 1,
|
||||
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError:
|
||||
"""
|
||||
Apply to MCP HTTP streamable server with stateless http
|
||||
Handle MCP request and return JSON-RPC response
|
||||
|
||||
Args:
|
||||
app: The Dify app instance
|
||||
request: The JSON-RPC request message
|
||||
user_input_form: List of variable entities for the app
|
||||
mcp_server: The MCP server configuration
|
||||
end_user: Optional end user
|
||||
request_id: The request ID
|
||||
|
||||
Returns:
|
||||
JSON-RPC response or error
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
|
||||
):
|
||||
self.app = app
|
||||
self.request = request
|
||||
mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first()
|
||||
if not mcp_server:
|
||||
raise ValueError("MCP server not found")
|
||||
self.mcp_server: AppMCPServer = mcp_server
|
||||
self.end_user = self.retrieve_end_user()
|
||||
self.user_input_form = user_input_form
|
||||
request_type = type(request.root)
|
||||
|
||||
@property
|
||||
def request_type(self):
|
||||
return type(self.request.root)
|
||||
def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse:
|
||||
"""Create success response with business result data"""
|
||||
return mcp_types.JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
result=result_data.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
@property
|
||||
def parameter_schema(self):
|
||||
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
|
||||
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": parameters,
|
||||
"required": required,
|
||||
}
|
||||
def create_error_response(code: int, message: str) -> mcp_types.JSONRPCError:
|
||||
"""Create error response with error code and message"""
|
||||
from core.mcp.types import ErrorData
|
||||
|
||||
error_data = ErrorData(code=code, message=message)
|
||||
return mcp_types.JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
error=error_data,
|
||||
)
|
||||
|
||||
# Request handler mapping using functional approach
|
||||
request_handlers = {
|
||||
mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description),
|
||||
mcp_types.ListToolsRequest: lambda: handle_list_tools(
|
||||
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
|
||||
),
|
||||
mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user),
|
||||
mcp_types.PingRequest: lambda: handle_ping(),
|
||||
}
|
||||
|
||||
try:
|
||||
# Dispatch request to appropriate handler
|
||||
handler = request_handlers.get(request_type)
|
||||
if handler:
|
||||
return create_success_response(handler())
|
||||
else:
|
||||
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Invalid params")
|
||||
return create_error_response(mcp_types.INVALID_PARAMS, str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Internal server error")
|
||||
return create_error_response(mcp_types.INTERNAL_ERROR, "Internal server error: " + str(e))
|
||||
|
||||
|
||||
def handle_ping() -> mcp_types.EmptyResult:
|
||||
"""Handle ping request"""
|
||||
return mcp_types.EmptyResult()
|
||||
|
||||
|
||||
def handle_initialize(description: str) -> mcp_types.InitializeResult:
|
||||
"""Handle initialize request"""
|
||||
capabilities = mcp_types.ServerCapabilities(
|
||||
tools=mcp_types.ToolsCapability(listChanged=False),
|
||||
)
|
||||
|
||||
return mcp_types.InitializeResult(
|
||||
protocolVersion=mcp_types.SERVER_LATEST_PROTOCOL_VERSION,
|
||||
capabilities=capabilities,
|
||||
serverInfo=mcp_types.Implementation(name="Dify", version=dify_config.project.version),
|
||||
instructions=description,
|
||||
)
|
||||
|
||||
|
||||
def handle_list_tools(
|
||||
app_name: str,
|
||||
app_mode: str,
|
||||
user_input_form: list[VariableEntity],
|
||||
description: str,
|
||||
parameters_dict: dict[str, str],
|
||||
) -> mcp_types.ListToolsResult:
|
||||
"""Handle list tools request"""
|
||||
parameter_schema = build_parameter_schema(app_mode, user_input_form, parameters_dict)
|
||||
|
||||
return mcp_types.ListToolsResult(
|
||||
tools=[
|
||||
mcp_types.Tool(
|
||||
name=app_name,
|
||||
description=description,
|
||||
inputSchema=parameter_schema,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def handle_call_tool(
|
||||
app: App,
|
||||
request: mcp_types.ClientRequest,
|
||||
user_input_form: list[VariableEntity],
|
||||
end_user: EndUser | None,
|
||||
) -> mcp_types.CallToolResult:
|
||||
"""Handle call tool request"""
|
||||
request_obj = cast(mcp_types.CallToolRequest, request.root)
|
||||
args = prepare_tool_arguments(app, request_obj.params.arguments or {})
|
||||
|
||||
if not end_user:
|
||||
raise ValueError("End user not found")
|
||||
|
||||
response = AppGenerateService.generate(
|
||||
app,
|
||||
end_user,
|
||||
args,
|
||||
InvokeFrom.SERVICE_API,
|
||||
streaming=app.mode == AppMode.AGENT_CHAT.value,
|
||||
)
|
||||
|
||||
answer = extract_answer_from_response(app, response)
|
||||
return mcp_types.CallToolResult(content=[mcp_types.TextContent(text=answer, type="text")])
|
||||
|
||||
|
||||
def build_parameter_schema(
|
||||
app_mode: str,
|
||||
user_input_form: list[VariableEntity],
|
||||
parameters_dict: dict[str, str],
|
||||
) -> dict[str, Any]:
|
||||
"""Build parameter schema for the tool"""
|
||||
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
|
||||
|
||||
if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "User Input/Question content"},
|
||||
**parameters,
|
||||
},
|
||||
"required": ["query", *required],
|
||||
"properties": parameters,
|
||||
"required": required,
|
||||
}
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "User Input/Question content"},
|
||||
**parameters,
|
||||
},
|
||||
"required": ["query", *required],
|
||||
}
|
||||
|
||||
@property
|
||||
def capabilities(self):
|
||||
return types.ServerCapabilities(
|
||||
tools=types.ToolsCapability(listChanged=False),
|
||||
)
|
||||
|
||||
def response(self, response: types.Result | str):
|
||||
if isinstance(response, str):
|
||||
sse_content = f"event: ping\ndata: {response}\n\n".encode()
|
||||
yield sse_content
|
||||
return
|
||||
json_response = types.JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=(self.request.root.model_extra or {}).get("id", 1),
|
||||
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
json_data = json.dumps(jsonable_encoder(json_response))
|
||||
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Prepare arguments based on app mode"""
|
||||
if app.mode == AppMode.WORKFLOW.value:
|
||||
return {"inputs": arguments}
|
||||
elif app.mode == AppMode.COMPLETION.value:
|
||||
return {"query": "", "inputs": arguments}
|
||||
else:
|
||||
# Chat modes - create a copy to avoid modifying original dict
|
||||
args_copy = arguments.copy()
|
||||
query = args_copy.pop("query", "")
|
||||
return {"query": query, "inputs": args_copy}
|
||||
|
||||
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
|
||||
|
||||
yield sse_content
|
||||
def extract_answer_from_response(app: App, response: Any) -> str:
|
||||
"""Extract answer from app generate response"""
|
||||
answer = ""
|
||||
|
||||
def error_response(self, code: int, message: str, data=None):
|
||||
request_id = (self.request.root.model_extra or {}).get("id", 1) or 1
|
||||
return create_mcp_error_response(request_id, code, message, data)
|
||||
if isinstance(response, RateLimitGenerator):
|
||||
answer = process_streaming_response(response)
|
||||
elif isinstance(response, Mapping):
|
||||
answer = process_mapping_response(app, response)
|
||||
else:
|
||||
logger.warning("Unexpected response type: %s", type(response))
|
||||
|
||||
def handle(self):
|
||||
handle_map = {
|
||||
types.InitializeRequest: self.initialize,
|
||||
types.ListToolsRequest: self.list_tools,
|
||||
types.CallToolRequest: self.invoke_tool,
|
||||
types.InitializedNotification: self.handle_notification,
|
||||
types.PingRequest: self.handle_ping,
|
||||
}
|
||||
try:
|
||||
if self.request_type in handle_map:
|
||||
return self.response(handle_map[self.request_type]())
|
||||
else:
|
||||
return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
|
||||
except ValueError as e:
|
||||
logger.exception("Invalid params")
|
||||
return self.error_response(INVALID_PARAMS, str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Internal server error")
|
||||
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
|
||||
return answer
|
||||
|
||||
def handle_notification(self):
|
||||
return "ping"
|
||||
|
||||
def handle_ping(self):
|
||||
return types.EmptyResult()
|
||||
|
||||
def initialize(self):
|
||||
request = cast(types.InitializeRequest, self.request.root)
|
||||
client_info = request.params.clientInfo
|
||||
client_name = f"{client_info.name}@{client_info.version}"
|
||||
if not self.end_user:
|
||||
end_user = EndUser(
|
||||
tenant_id=self.app.tenant_id,
|
||||
app_id=self.app.id,
|
||||
type="mcp",
|
||||
name=client_name,
|
||||
session_id=generate_session_id(),
|
||||
external_user_id=self.mcp_server.id,
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
return types.InitializeResult(
|
||||
protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION,
|
||||
capabilities=self.capabilities,
|
||||
serverInfo=types.Implementation(name="Dify", version=dify_config.project.version),
|
||||
instructions=self.mcp_server.description,
|
||||
)
|
||||
|
||||
def list_tools(self):
|
||||
if not self.end_user:
|
||||
raise ValueError("User not found")
|
||||
return types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name=self.app.name,
|
||||
description=self.mcp_server.description,
|
||||
inputSchema=self.parameter_schema,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def invoke_tool(self):
|
||||
if not self.end_user:
|
||||
raise ValueError("User not found")
|
||||
request = cast(types.CallToolRequest, self.request.root)
|
||||
args = request.params.arguments or {}
|
||||
if self.app.mode in {AppMode.WORKFLOW.value}:
|
||||
args = {"inputs": args}
|
||||
elif self.app.mode in {AppMode.COMPLETION.value}:
|
||||
args = {"query": "", "inputs": args}
|
||||
else:
|
||||
args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
|
||||
response = AppGenerateService.generate(
|
||||
self.app,
|
||||
self.end_user,
|
||||
args,
|
||||
InvokeFrom.SERVICE_API,
|
||||
streaming=self.app.mode == AppMode.AGENT_CHAT.value,
|
||||
)
|
||||
answer = ""
|
||||
if isinstance(response, RateLimitGenerator):
|
||||
for item in response.generator:
|
||||
data = item
|
||||
if isinstance(data, str) and data.startswith("data: "):
|
||||
try:
|
||||
json_str = data[6:].strip()
|
||||
parsed_data = json.loads(json_str)
|
||||
if parsed_data.get("event") == "agent_thought":
|
||||
answer += parsed_data.get("thought", "")
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if isinstance(response, Mapping):
|
||||
if self.app.mode in {
|
||||
AppMode.ADVANCED_CHAT.value,
|
||||
AppMode.COMPLETION.value,
|
||||
AppMode.CHAT.value,
|
||||
AppMode.AGENT_CHAT.value,
|
||||
}:
|
||||
answer = response["answer"]
|
||||
elif self.app.mode in {AppMode.WORKFLOW.value}:
|
||||
answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError("Invalid app mode")
|
||||
# Not support image yet
|
||||
return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
|
||||
|
||||
def retrieve_end_user(self):
|
||||
return (
|
||||
db.session.query(EndUser)
|
||||
.where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
|
||||
.first()
|
||||
)
|
||||
|
||||
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
|
||||
parameters: dict[str, dict[str, Any]] = {}
|
||||
required = []
|
||||
for item in user_input_form:
|
||||
parameters[item.variable] = {}
|
||||
if item.type in (
|
||||
VariableEntityType.FILE,
|
||||
VariableEntityType.FILE_LIST,
|
||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||
):
|
||||
continue
|
||||
if item.required:
|
||||
required.append(item.variable)
|
||||
# if the workflow republished, the parameters not changed
|
||||
# we should not raise error here
|
||||
def process_streaming_response(response: RateLimitGenerator) -> str:
|
||||
"""Process streaming response for agent chat mode"""
|
||||
answer = ""
|
||||
for item in response.generator:
|
||||
if isinstance(item, str) and item.startswith("data: "):
|
||||
try:
|
||||
description = self.mcp_server.parameters_dict[item.variable]
|
||||
except KeyError:
|
||||
description = ""
|
||||
parameters[item.variable]["description"] = description
|
||||
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
||||
parameters[item.variable]["type"] = "string"
|
||||
elif item.type == VariableEntityType.SELECT:
|
||||
parameters[item.variable]["type"] = "string"
|
||||
parameters[item.variable]["enum"] = item.options
|
||||
elif item.type == VariableEntityType.NUMBER:
|
||||
parameters[item.variable]["type"] = "float"
|
||||
return parameters, required
|
||||
json_str = item[6:].strip()
|
||||
parsed_data = json.loads(json_str)
|
||||
if parsed_data.get("event") == "agent_thought":
|
||||
answer += parsed_data.get("thought", "")
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
return answer
|
||||
|
||||
|
||||
def process_mapping_response(app: App, response: Mapping) -> str:
|
||||
"""Process mapping response based on app mode"""
|
||||
if app.mode in {
|
||||
AppMode.ADVANCED_CHAT.value,
|
||||
AppMode.COMPLETION.value,
|
||||
AppMode.CHAT.value,
|
||||
AppMode.AGENT_CHAT.value,
|
||||
}:
|
||||
return response.get("answer", "")
|
||||
elif app.mode == AppMode.WORKFLOW.value:
|
||||
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError("Invalid app mode: " + str(app.mode))
|
||||
|
||||
|
||||
def convert_input_form_to_parameters(
|
||||
user_input_form: list[VariableEntity],
|
||||
parameters_dict: dict[str, str],
|
||||
) -> tuple[dict[str, dict[str, Any]], list[str]]:
|
||||
"""Convert user input form to parameter schema"""
|
||||
parameters: dict[str, dict[str, Any]] = {}
|
||||
required = []
|
||||
|
||||
for item in user_input_form:
|
||||
if item.type in (
|
||||
VariableEntityType.FILE,
|
||||
VariableEntityType.FILE_LIST,
|
||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||
):
|
||||
continue
|
||||
parameters[item.variable] = {}
|
||||
if item.required:
|
||||
required.append(item.variable)
|
||||
# if the workflow republished, the parameters not changed
|
||||
# we should not raise error here
|
||||
description = parameters_dict.get(item.variable, "")
|
||||
parameters[item.variable]["description"] = description
|
||||
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
||||
parameters[item.variable]["type"] = "string"
|
||||
elif item.type == VariableEntityType.SELECT:
|
||||
parameters[item.variable]["type"] = "string"
|
||||
parameters[item.variable]["enum"] = item.options
|
||||
elif item.type == VariableEntityType.NUMBER:
|
||||
parameters[item.variable]["type"] = "float"
|
||||
return parameters, required
|
||||
|
Reference in New Issue
Block a user