feat: add MCP server headers support #22718 (#24760)
Some checks failed
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Check i18n Files and Create PR / check-and-update (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Some checks failed
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Check i18n Files and Create PR / check-and-update (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: Novice <novice12185727@gmail.com>
This commit is contained in:
@@ -865,6 +865,7 @@ class ToolProviderMCPApi(Resource):
|
||||
parser.add_argument(
|
||||
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
|
||||
)
|
||||
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||
args = parser.parse_args()
|
||||
user = current_user
|
||||
if not is_valid_url(args["server_url"]):
|
||||
@@ -881,6 +882,7 @@ class ToolProviderMCPApi(Resource):
|
||||
server_identifier=args["server_identifier"],
|
||||
timeout=args["timeout"],
|
||||
sse_read_timeout=args["sse_read_timeout"],
|
||||
headers=args["headers"],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -898,6 +900,7 @@ class ToolProviderMCPApi(Resource):
|
||||
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
|
||||
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
|
||||
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
if not is_valid_url(args["server_url"]):
|
||||
if "[__HIDDEN__]" in args["server_url"]:
|
||||
@@ -915,6 +918,7 @@ class ToolProviderMCPApi(Resource):
|
||||
server_identifier=args["server_identifier"],
|
||||
timeout=args.get("timeout"),
|
||||
sse_read_timeout=args.get("sse_read_timeout"),
|
||||
headers=args.get("headers"),
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
@@ -951,6 +955,9 @@ class ToolMCPAuthApi(Resource):
|
||||
authed=False,
|
||||
authorization_code=args["authorization_code"],
|
||||
for_list=True,
|
||||
headers=provider.decrypted_headers,
|
||||
timeout=provider.timeout,
|
||||
sse_read_timeout=provider.sse_read_timeout,
|
||||
):
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
mcp_provider=provider,
|
||||
|
@@ -43,6 +43,10 @@ class ToolProviderApiEntity(BaseModel):
|
||||
server_url: Optional[str] = Field(default="", description="The server url of the tool")
|
||||
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||
server_identifier: Optional[str] = Field(default="", description="The server identifier of the MCP tool")
|
||||
timeout: Optional[float] = Field(default=30.0, description="The timeout of the MCP tool")
|
||||
sse_read_timeout: Optional[float] = Field(default=300.0, description="The SSE read timeout of the MCP tool")
|
||||
masked_headers: Optional[dict[str, str]] = Field(default=None, description="The masked headers of the MCP tool")
|
||||
original_headers: Optional[dict[str, str]] = Field(default=None, description="The original headers of the MCP tool")
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
@@ -65,6 +69,10 @@ class ToolProviderApiEntity(BaseModel):
|
||||
if self.type == ToolProviderType.MCP:
|
||||
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
||||
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
|
||||
optional_fields.update(self.optional_field("timeout", self.timeout))
|
||||
optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
|
@@ -94,7 +94,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
provider_id=db_provider.server_identifier or "",
|
||||
tenant_id=db_provider.tenant_id or "",
|
||||
server_url=db_provider.decrypted_server_url,
|
||||
headers={}, # TODO: get headers from db provider
|
||||
headers=db_provider.decrypted_headers or {},
|
||||
timeout=db_provider.timeout,
|
||||
sse_read_timeout=db_provider.sse_read_timeout,
|
||||
)
|
||||
|
@@ -0,0 +1,27 @@
|
||||
"""add_headers_to_mcp_provider
|
||||
|
||||
Revision ID: c20211f18133
|
||||
Revises: 8d289573e1da
|
||||
Create Date: 2025-08-29 10:07:54.163626
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'c20211f18133'
|
||||
down_revision = 'b95962a3885c'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Add encrypted_headers column to tool_mcp_providers table
|
||||
op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', sa.Text(), nullable=True))
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Remove encrypted_headers column from tool_mcp_providers table
|
||||
op.drop_column('tool_mcp_providers', 'encrypted_headers')
|
@@ -280,6 +280,8 @@ class MCPToolProvider(Base):
|
||||
)
|
||||
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"))
|
||||
sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300"))
|
||||
# encrypted headers for MCP server requests
|
||||
encrypted_headers: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
|
||||
|
||||
def load_user(self) -> Account | None:
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
@@ -310,6 +312,62 @@ class MCPToolProvider(Base):
|
||||
def decrypted_server_url(self) -> str:
|
||||
return encrypter.decrypt_token(self.tenant_id, self.server_url)
|
||||
|
||||
@property
|
||||
def decrypted_headers(self) -> dict[str, Any]:
|
||||
"""Get decrypted headers for MCP server requests."""
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
try:
|
||||
if not self.encrypted_headers:
|
||||
return {}
|
||||
|
||||
headers_data = json.loads(self.encrypted_headers)
|
||||
|
||||
# Create dynamic config for all headers as SECRET_INPUT
|
||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=self.tenant_id,
|
||||
config=config,
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
result = encrypter_instance.decrypt(headers_data)
|
||||
return result
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def masked_headers(self) -> dict[str, Any]:
|
||||
"""Get masked headers for frontend display."""
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
try:
|
||||
if not self.encrypted_headers:
|
||||
return {}
|
||||
|
||||
headers_data = json.loads(self.encrypted_headers)
|
||||
|
||||
# Create dynamic config for all headers as SECRET_INPUT
|
||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=self.tenant_id,
|
||||
config=config,
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
# First decrypt, then mask
|
||||
decrypted_headers = encrypter_instance.decrypt(headers_data)
|
||||
result = encrypter_instance.mask_tool_credentials(decrypted_headers)
|
||||
return result
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def masked_server_url(self) -> str:
|
||||
def mask_url(url: str, mask_char: str = "*") -> str:
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@@ -27,6 +27,36 @@ class MCPToolManageService:
|
||||
Service class for managing mcp tools.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]:
|
||||
"""
|
||||
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
|
||||
|
||||
Args:
|
||||
headers: Dictionary of headers to encrypt
|
||||
tenant_id: Tenant ID for encryption
|
||||
|
||||
Returns:
|
||||
Dictionary with all headers encrypted
|
||||
"""
|
||||
if not headers:
|
||||
return {}
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
# Create dynamic config for all headers as SECRET_INPUT
|
||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
return cast(dict[str, str], encrypter_instance.encrypt(headers))
|
||||
|
||||
@staticmethod
|
||||
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
|
||||
res = (
|
||||
@@ -61,6 +91,7 @@ class MCPToolManageService:
|
||||
server_identifier: str,
|
||||
timeout: float,
|
||||
sse_read_timeout: float,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> ToolProviderApiEntity:
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
existing_provider = (
|
||||
@@ -83,6 +114,12 @@ class MCPToolManageService:
|
||||
if existing_provider.server_identifier == server_identifier:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
# Encrypt headers
|
||||
encrypted_headers = None
|
||||
if headers:
|
||||
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
|
||||
encrypted_headers = json.dumps(encrypted_headers_dict)
|
||||
|
||||
mcp_tool = MCPToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
@@ -95,6 +132,7 @@ class MCPToolManageService:
|
||||
server_identifier=server_identifier,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
encrypted_headers=encrypted_headers,
|
||||
)
|
||||
db.session.add(mcp_tool)
|
||||
db.session.commit()
|
||||
@@ -118,9 +156,21 @@ class MCPToolManageService:
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
server_url = mcp_provider.decrypted_server_url
|
||||
authed = mcp_provider.authed
|
||||
headers = mcp_provider.decrypted_headers
|
||||
timeout = mcp_provider.timeout
|
||||
sse_read_timeout = mcp_provider.sse_read_timeout
|
||||
|
||||
try:
|
||||
with MCPClient(server_url, provider_id, tenant_id, authed=authed, for_list=True) as mcp_client:
|
||||
with MCPClient(
|
||||
server_url,
|
||||
provider_id,
|
||||
tenant_id,
|
||||
authed=authed,
|
||||
for_list=True,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
except MCPAuthError:
|
||||
raise ValueError("Please auth the tool first")
|
||||
@@ -172,6 +222,7 @@ class MCPToolManageService:
|
||||
server_identifier: str,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
|
||||
@@ -207,6 +258,13 @@ class MCPToolManageService:
|
||||
mcp_provider.timeout = timeout
|
||||
if sse_read_timeout is not None:
|
||||
mcp_provider.sse_read_timeout = sse_read_timeout
|
||||
if headers is not None:
|
||||
# Encrypt headers
|
||||
if headers:
|
||||
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
|
||||
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
|
||||
else:
|
||||
mcp_provider.encrypted_headers = None
|
||||
db.session.commit()
|
||||
except IntegrityError as e:
|
||||
db.session.rollback()
|
||||
@@ -242,6 +300,12 @@ class MCPToolManageService:
|
||||
|
||||
@classmethod
|
||||
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str):
|
||||
# Get the existing provider to access headers and timeout settings
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
headers = mcp_provider.decrypted_headers
|
||||
timeout = mcp_provider.timeout
|
||||
sse_read_timeout = mcp_provider.sse_read_timeout
|
||||
|
||||
try:
|
||||
with MCPClient(
|
||||
server_url,
|
||||
@@ -249,6 +313,9 @@ class MCPToolManageService:
|
||||
tenant_id,
|
||||
authed=False,
|
||||
for_list=True,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
return {
|
||||
|
@@ -237,6 +237,10 @@ class ToolTransformService:
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
server_identifier=db_provider.server_identifier,
|
||||
timeout=db_provider.timeout,
|
||||
sse_read_timeout=db_provider.sse_read_timeout,
|
||||
masked_headers=db_provider.masked_headers,
|
||||
original_headers=db_provider.decrypted_headers,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@@ -706,7 +706,14 @@ class TestMCPToolManageService:
|
||||
|
||||
# Verify mock interactions
|
||||
mock_mcp_client.assert_called_once_with(
|
||||
"https://example.com/mcp", mcp_provider.id, tenant.id, authed=False, for_list=True
|
||||
"https://example.com/mcp",
|
||||
mcp_provider.id,
|
||||
tenant.id,
|
||||
authed=False,
|
||||
for_list=True,
|
||||
headers={},
|
||||
timeout=30.0,
|
||||
sse_read_timeout=300.0,
|
||||
)
|
||||
|
||||
def test_list_mcp_tool_from_remote_server_auth_error(
|
||||
@@ -1181,6 +1188,11 @@ class TestMCPToolManageService:
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create MCP provider first
|
||||
mcp_provider = self._create_test_mcp_provider(
|
||||
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
|
||||
)
|
||||
|
||||
# Mock MCPClient and its context manager
|
||||
mock_tools = [
|
||||
type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_1", "description": "Test tool 1"}})(),
|
||||
@@ -1194,7 +1206,7 @@ class TestMCPToolManageService:
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = MCPToolManageService._re_connect_mcp_provider(
|
||||
"https://example.com/mcp", "test_provider_id", tenant.id
|
||||
"https://example.com/mcp", mcp_provider.id, tenant.id
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
@@ -1213,7 +1225,14 @@ class TestMCPToolManageService:
|
||||
|
||||
# Verify mock interactions
|
||||
mock_mcp_client.assert_called_once_with(
|
||||
"https://example.com/mcp", "test_provider_id", tenant.id, authed=False, for_list=True
|
||||
"https://example.com/mcp",
|
||||
mcp_provider.id,
|
||||
tenant.id,
|
||||
authed=False,
|
||||
for_list=True,
|
||||
headers={},
|
||||
timeout=30.0,
|
||||
sse_read_timeout=300.0,
|
||||
)
|
||||
|
||||
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
@@ -1231,6 +1250,11 @@ class TestMCPToolManageService:
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create MCP provider first
|
||||
mcp_provider = self._create_test_mcp_provider(
|
||||
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
|
||||
)
|
||||
|
||||
# Mock MCPClient to raise authentication error
|
||||
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
|
||||
from core.mcp.error import MCPAuthError
|
||||
@@ -1240,7 +1264,7 @@ class TestMCPToolManageService:
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = MCPToolManageService._re_connect_mcp_provider(
|
||||
"https://example.com/mcp", "test_provider_id", tenant.id
|
||||
"https://example.com/mcp", mcp_provider.id, tenant.id
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
@@ -1265,6 +1289,11 @@ class TestMCPToolManageService:
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create MCP provider first
|
||||
mcp_provider = self._create_test_mcp_provider(
|
||||
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
|
||||
)
|
||||
|
||||
# Mock MCPClient to raise connection error
|
||||
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
|
||||
from core.mcp.error import MCPError
|
||||
@@ -1274,4 +1303,4 @@ class TestMCPToolManageService:
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
|
||||
MCPToolManageService._re_connect_mcp_provider("https://example.com/mcp", "test_provider_id", tenant.id)
|
||||
MCPToolManageService._re_connect_mcp_provider("https://example.com/mcp", mcp_provider.id, tenant.id)
|
||||
|
Reference in New Issue
Block a user