Support OAuth Integration for Plugin Tools (#22550)

Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
Maries
2025-07-17 17:18:44 +08:00
committed by GitHub
parent 965e952336
commit a4ef900916
89 changed files with 5516 additions and 875 deletions

View File

@@ -6,6 +6,7 @@ on:
- "main"
- "deploy/dev"
- "deploy/enterprise"
- "build/**"
tags:
- "*"

View File

@@ -5,17 +5,17 @@
SECRET_KEY=
# Console API base URL
CONSOLE_API_URL=http://127.0.0.1:5001
CONSOLE_WEB_URL=http://127.0.0.1:3000
CONSOLE_API_URL=http://localhost:5001
CONSOLE_WEB_URL=http://localhost:3000
# Service API base URL
SERVICE_API_URL=http://127.0.0.1:5001
SERVICE_API_URL=http://localhost:5001
# Web APP base URL
APP_WEB_URL=http://127.0.0.1:3000
APP_WEB_URL=http://localhost:3000
# Files URL
FILES_URL=http://127.0.0.1:5001
FILES_URL=http://localhost:5001
# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network.
# Set this to the internal Docker service URL for proper plugin file access.
@@ -138,8 +138,8 @@ SUPABASE_API_KEY=your-access-key
SUPABASE_URL=your-server-url
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
# Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone

View File

@@ -2,19 +2,22 @@ import base64
import json
import logging
import secrets
from typing import Optional
from typing import Any, Optional
import click
from flask import current_app
from pydantic import TypeAdapter
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from configs import dify_config
from constants.languages import languages
from core.plugin.entities.plugin import ToolProviderID
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.models.document import Document
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from events.app_event import app_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -27,6 +30,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, D
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.provider import Provider, ProviderModel
from models.tools import ToolOAuthSystemClient
from services.account_service import AccountService, RegisterService, TenantService
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
from services.plugin.data_migration import PluginDataMigration
@@ -1155,3 +1159,49 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green"))
else:
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
def setup_system_tool_oauth_client(provider, client_params):
"""
Setup system tool oauth client
"""
provider_id = ToolProviderID(provider)
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
try:
# json validate
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
click.echo(click.style("Client params validated successfully.", fg="green"))
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = (
db.session.query(ToolOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
oauth_client = ToolOAuthSystemClient(
provider=provider_name,
plugin_id=plugin_id,
encrypted_oauth_params=oauth_client_params,
)
db.session.add(oauth_client)
db.session.commit()
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))

View File

@@ -1,6 +1,7 @@
from configs import dify_config
HIDDEN_VALUE = "[__HIDDEN__]"
UNKNOWN_VALUE = "[__UNKNOWN__]"
UUID_NIL = "00000000-0000-0000-0000-000000000000"
DEFAULT_FILE_NUMBER_LIMITS = 3

View File

@@ -1,23 +1,32 @@
import io
from urllib.parse import urlparse
from flask import redirect, send_file
from flask import make_response, redirect, request, send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
from flask_restful import (
Resource,
reqparse,
)
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
)
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.mcp_tools_mange_service import MCPToolManageService
@@ -89,7 +98,7 @@ class ToolBuiltinProviderInfoApi(Resource):
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider))
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
class ToolBuiltinProviderDeleteApi(Resource):
@@ -98,17 +107,47 @@ class ToolBuiltinProviderDeleteApi(Resource):
@account_initialization_required
def post(self, provider):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = req.parse_args()
return BuiltinToolManageService.delete_builtin_tool_provider(
tenant_id,
provider,
args["credential_id"],
)
class ToolBuiltinProviderAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return BuiltinToolManageService.delete_builtin_tool_provider(
user_id,
tenant_id,
provider,
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if args["type"] not in CredentialType.values():
raise ValueError(f"Invalid credential type: {args['type']}")
return BuiltinToolManageService.add_builtin_tool_provider(
user_id=user_id,
tenant_id=tenant_id,
provider=provider,
credentials=args["credentials"],
name=args["name"],
api_type=CredentialType.of(args["type"]),
)
@@ -126,19 +165,20 @@ class ToolBuiltinProviderUpdateApi(Resource):
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
with Session(db.engine) as session:
result = BuiltinToolManageService.update_builtin_tool_provider(
session=session,
user_id=user_id,
tenant_id=tenant_id,
provider_name=provider,
credentials=args["credentials"],
)
session.commit()
result = BuiltinToolManageService.update_builtin_tool_provider(
user_id=user_id,
tenant_id=tenant_id,
provider=provider,
credential_id=args["credential_id"],
credentials=args.get("credentials", None),
name=args.get("name", ""),
)
return result
@@ -149,9 +189,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
def get(self, provider):
tenant_id = current_user.current_tenant_id
return BuiltinToolManageService.get_builtin_tool_provider_credentials(
tenant_id=tenant_id,
provider_name=provider,
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credentials(
tenant_id=tenant_id,
provider_name=provider,
)
)
@@ -344,12 +386,15 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider, credential_type):
user = current_user
tenant_id = user.current_tenant_id
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id)
return jsonable_encoder(
BuiltinToolManageService.list_builtin_provider_credentials_schema(
provider, CredentialType.of(credential_type), tenant_id
)
)
class ToolApiProviderSchemaApi(Resource):
@@ -586,15 +631,12 @@ class ToolApiListApi(Resource):
@account_initialization_required
def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
provider.to_dict()
for provider in ApiToolManageService.list_api_tools(
user_id,
tenant_id,
)
]
@@ -631,6 +673,179 @@ class ToolLabelsApi(Resource):
return jsonable_encoder(ToolLabelsService.list_tool_labels())
class ToolPluginOAuthApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
# todo check permission
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
)
response = make_response(jsonable_encoder(authorization_url_response))
response.set_cookie(
"context_id",
context_id,
httponly=True,
samesite="Lax",
max_age=OAuthProxyService.__MAX_AGE__,
)
return response
class ToolOAuthCallback(Resource):
@setup_required
def get(self, provider):
context_id = request.cookies.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
context = OAuthProxyService.use_proxy_context(context_id)
if context is None:
raise Forbidden("Invalid context_id")
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
oauth_handler = OAuthHandler()
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
credentials = oauth_handler.get_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
request=request,
).credentials
if not credentials:
raise Exception("the plugin credentials failed")
# add credentials to database
BuiltinToolManageService.add_builtin_tool_provider(
user_id=user_id,
tenant_id=tenant_id,
provider=provider,
credentials=dict(credentials),
api_type=CredentialType.OAUTH2,
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
class ToolBuiltinProviderSetDefaultApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return BuiltinToolManageService.set_default_provider(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
)
class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider=provider,
client_params=args.get("client_params", {}),
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
)
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
)
@setup_required
@login_required
@account_initialization_required
def delete(self, provider):
return jsonable_encoder(
BuiltinToolManageService.delete_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
)
class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
tenant_id=current_user.current_tenant_id, provider_name=provider
)
)
class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
tenant_id=tenant_id,
provider=provider,
)
)
class ToolProviderMCPApi(Resource):
@setup_required
@login_required
@@ -794,17 +1009,33 @@ class ToolMCPCallbackApi(Resource):
# tool provider
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
# tool oauth
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/<path:provider>/tool/authorization-url")
api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback")
api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin/<path:provider>/add")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource(
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/default-credential"
)
api.add_resource(
ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credential/info"
)
api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
)
api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema",
"/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>",
)
api.add_resource(
ToolBuiltinProviderGetOauthClientSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema",
)
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")

View File

@@ -175,6 +175,7 @@ class PluginInvokeToolApi(Resource):
provider=payload.provider,
tool_name=payload.tool,
tool_parameters=payload.tool_parameters,
credential_id=payload.credential_id,
),
)

View File

@@ -16,6 +16,7 @@ class AgentToolEntity(BaseModel):
tool_name: str
tool_parameters: dict[str, Any] = Field(default_factory=dict)
plugin_unique_identifier: str | None = None
credential_id: str | None = None
class AgentPromptEntity(BaseModel):

View File

@@ -4,6 +4,7 @@ from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyParameter
from core.plugin.entities.request import InvokeCredentials
class BaseAgentStrategy(ABC):
@@ -18,11 +19,12 @@ class BaseAgentStrategy(ABC):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
"""
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
yield from self._invoke(params, user_id, conversation_id, app_id, message_id, credentials)
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
"""
@@ -38,5 +40,6 @@ class BaseAgentStrategy(ABC):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]:
pass

View File

@@ -4,6 +4,7 @@ from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
from core.agent.strategy.base import BaseAgentStrategy
from core.plugin.entities.request import InvokeCredentials, PluginInvokeContext
from core.plugin.impl.agent import PluginAgentClient
from core.plugin.utils.converter import convert_parameters_to_plugin_format
@@ -40,6 +41,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
@@ -58,4 +60,5 @@ class PluginAgentStrategy(BaseAgentStrategy):
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
context=PluginInvokeContext(credentials=credentials or InvokeCredentials()),
)

View File

@@ -39,6 +39,7 @@ class AgentConfigManager:
"provider_id": tool["provider_id"],
"tool_name": tool["tool_name"],
"tool_parameters": tool.get("tool_parameters", {}),
"credential_id": tool.get("credential_id", None),
}
agent_tools.append(AgentToolEntity(**agent_tool_properties))

View File

@@ -0,0 +1,84 @@
import json
from abc import ABC, abstractmethod
from json import JSONDecodeError
from typing import Any, Optional
from extensions.ext_redis import redis_client
class ProviderCredentialsCache(ABC):
"""Base class for provider credentials cache"""
def __init__(self, **kwargs):
self.cache_key = self._generate_cache_key(**kwargs)
@abstractmethod
def _generate_cache_key(self, **kwargs) -> str:
"""Generate cache key based on subclass implementation"""
pass
def get(self) -> Optional[dict]:
"""Get cached provider credentials"""
cached_credentials = redis_client.get(self.cache_key)
if cached_credentials:
try:
cached_credentials = cached_credentials.decode("utf-8")
return dict(json.loads(cached_credentials))
except JSONDecodeError:
return None
return None
def set(self, config: dict[str, Any]) -> None:
"""Cache provider credentials"""
redis_client.setex(self.cache_key, 86400, json.dumps(config))
def delete(self) -> None:
"""Delete cached provider credentials"""
redis_client.delete(self.cache_key)
class SingletonProviderCredentialsCache(ProviderCredentialsCache):
"""Cache for tool single provider credentials"""
def __init__(self, tenant_id: str, provider_type: str, provider_identity: str):
super().__init__(
tenant_id=tenant_id,
provider_type=provider_type,
provider_identity=provider_identity,
)
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider_type = kwargs["provider_type"]
identity_name = kwargs["provider_identity"]
identity_id = f"{provider_type}.{identity_name}"
return f"{provider_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
class ToolProviderCredentialsCache(ProviderCredentialsCache):
"""Cache for tool provider credentials"""
def __init__(self, tenant_id: str, provider: str, credential_id: str):
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider = kwargs["provider"]
credential_id = kwargs["credential_id"]
return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
class NoOpProviderCredentialCache:
"""No-op provider credential cache"""
def get(self) -> Optional[dict]:
"""Get cached provider credentials"""
return None
def set(self, config: dict[str, Any]) -> None:
"""Cache provider credentials"""
pass
def delete(self) -> None:
"""Delete cached provider credentials"""
pass

View File

@@ -1,51 +0,0 @@
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ToolProviderCredentialsCacheType(Enum):
PROVIDER = "tool_provider"
ENDPOINT = "endpoint"
class ToolProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> Optional[dict]:
"""
Get cached model provider credentials.
:return:
"""
cached_provider_credentials = redis_client.get(self.cache_key)
if cached_provider_credentials:
try:
cached_provider_credentials = cached_provider_credentials.decode("utf-8")
cached_provider_credentials = json.loads(cached_provider_credentials)
except JSONDecodeError:
return None
return dict(cached_provider_credentials)
else:
return None
def set(self, credentials: dict) -> None:
"""
Cache model provider credentials.
:param credentials: provider credentials
:return:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
def delete(self) -> None:
"""
Delete cached model provider credentials.
:return:
"""
redis_client.delete(self.cache_key)

View File

@@ -1,16 +1,20 @@
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.plugin.entities.request import RequestInvokeEncrypt
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_provider_encrypter
from models.account import Tenant
class PluginEncrypter:
@classmethod
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
encrypter = ProviderConfigEncrypter(
encrypter, cache = create_provider_encrypter(
tenant_id=tenant.id,
config=payload.config,
provider_type=payload.namespace,
provider_identity=payload.identity,
cache=SingletonProviderCredentialsCache(
tenant_id=tenant.id,
provider_type=payload.namespace,
provider_identity=payload.identity,
),
)
if payload.opt == "encrypt":
@@ -22,7 +26,7 @@ class PluginEncrypter:
"data": encrypter.decrypt(payload.data),
}
elif payload.opt == "clear":
encrypter.delete_tool_credentials_cache()
cache.delete()
return {
"data": {},
}

View File

@@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any
from typing import Any, Optional
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
@@ -23,6 +23,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
provider: str,
tool_name: str,
tool_parameters: dict[str, Any],
credential_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke tool
@@ -30,7 +31,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
# get tool runtime
try:
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
tool_type, tenant_id, provider, tool_name, tool_parameters
tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id
)
response = ToolEngine.generic_invoke(
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1

View File

@@ -27,6 +27,20 @@ from core.workflow.nodes.question_classifier.entities import (
)
class InvokeCredentials(BaseModel):
tool_credentials: dict[str, str] = Field(
default_factory=dict,
description="Map of tool provider to credential id, used to store the credential id for the tool provider.",
)
class PluginInvokeContext(BaseModel):
credentials: Optional[InvokeCredentials] = Field(
default_factory=InvokeCredentials,
description="Credentials context for the plugin invocation or backward invocation.",
)
class RequestInvokeTool(BaseModel):
"""
Request to invoke a tool
@@ -36,6 +50,7 @@ class RequestInvokeTool(BaseModel):
provider: str
tool: str
tool_parameters: dict
credential_id: Optional[str] = None
class BaseRequestInvokeModel(BaseModel):

View File

@@ -6,6 +6,7 @@ from core.plugin.entities.plugin import GenericProviderID
from core.plugin.entities.plugin_daemon import (
PluginAgentProviderEntity,
)
from core.plugin.entities.request import PluginInvokeContext
from core.plugin.impl.base import BasePluginClient
@@ -83,6 +84,7 @@ class PluginAgentClient(BasePluginClient):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
context: Optional[PluginInvokeContext] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent with the given tenant, user, plugin, provider, name and parameters.
@@ -99,6 +101,7 @@ class PluginAgentClient(BasePluginClient):
"conversation_id": conversation_id,
"app_id": app_id,
"message_id": message_id,
"context": context.model_dump() if context else {},
"data": {
"agent_strategy_provider": agent_provider_id.provider_name,
"agent_strategy": agent_strategy,

View File

@@ -15,27 +15,32 @@ class OAuthHandler(BasePluginClient):
user_id: str,
plugin_id: str,
provider: str,
redirect_uri: str,
system_credentials: Mapping[str, Any],
) -> PluginOAuthAuthorizationUrlResponse:
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
PluginOAuthAuthorizationUrlResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"system_credentials": system_credentials,
try:
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
PluginOAuthAuthorizationUrlResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"redirect_uri": redirect_uri,
"system_credentials": system_credentials,
},
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
except Exception as e:
raise ValueError(f"Error getting authorization URL: {e}")
def get_credentials(
self,
@@ -43,6 +48,7 @@ class OAuthHandler(BasePluginClient):
user_id: str,
plugin_id: str,
provider: str,
redirect_uri: str,
system_credentials: Mapping[str, Any],
request: Request,
) -> PluginOAuthCredentialsResponse:
@@ -50,30 +56,33 @@ class OAuthHandler(BasePluginClient):
Get credentials from the given request.
"""
# encode request to raw http request
raw_request_bytes = self._convert_request_to_raw_data(request)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
PluginOAuthCredentialsResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"system_credentials": system_credentials,
# for json serialization
"raw_http_request": binascii.hexlify(raw_request_bytes).decode(),
try:
# encode request to raw http request
raw_request_bytes = self._convert_request_to_raw_data(request)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
PluginOAuthCredentialsResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"redirect_uri": redirect_uri,
"system_credentials": system_credentials,
# for json serialization
"raw_http_request": binascii.hexlify(raw_request_bytes).decode(),
},
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
except Exception as e:
raise ValueError(f"Error getting credentials: {e}")
def _convert_request_to_raw_data(self, request: Request) -> bytes:
"""

View File

@@ -6,7 +6,7 @@ from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
class PluginToolManager(BasePluginClient):
@@ -78,6 +78,7 @@ class PluginToolManager(BasePluginClient):
tool_provider: str,
tool_name: str,
credentials: dict[str, Any],
credential_type: CredentialType,
tool_parameters: dict[str, Any],
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
@@ -102,6 +103,7 @@ class PluginToolManager(BasePluginClient):
"provider": tool_provider_id.provider_name,
"tool": tool_name,
"credentials": credentials,
"credential_type": credential_type,
"tool_parameters": tool_parameters,
},
},

View File

@@ -4,7 +4,7 @@ from openai import BaseModel
from pydantic import Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolInvokeFrom
from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
class ToolRuntime(BaseModel):
@@ -17,6 +17,7 @@ class ToolRuntime(BaseModel):
invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None
credentials: dict[str, Any] = Field(default_factory=dict)
credential_type: CredentialType = Field(default=CredentialType.API_KEY)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)

View File

@@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
from core.tools.entities.tool_entities import (
CredentialType,
OAuthSchema,
ToolEntity,
ToolProviderEntity,
ToolProviderType,
)
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import (
ToolProviderNotFoundError,
@@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController):
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
credentials_schema.append(credential_dict)
oauth_schema = None
if provider_yaml.get("oauth_schema", None) is not None:
oauth_schema = OAuthSchema(
client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []),
credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []),
)
super().__init__(
entity=ToolProviderEntity(
identity=provider_yaml["identity"],
credentials_schema=credentials_schema,
oauth_schema=oauth_schema,
),
)
@@ -97,10 +111,39 @@ class BuiltinToolProviderController(ToolProviderController):
:return: the credentials schema
"""
if not self.entity.credentials_schema:
return []
return self.get_credentials_schema_by_type(CredentialType.API_KEY.value)
return self.entity.credentials_schema.copy()
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider
:param credential_type: the type of the credential
:return: the credentials schema of the provider
"""
if credential_type == CredentialType.OAUTH2.value:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
if credential_type == CredentialType.API_KEY.value:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
raise ValueError(f"Invalid credential type: {credential_type}")
def get_oauth_client_schema(self) -> list[ProviderConfig]:
"""
returns the oauth client schema of the provider
:return: the oauth client schema
"""
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
def get_supported_credential_types(self) -> list[str]:
"""
returns the credential support type of the provider
"""
types = []
if self.entity.credentials_schema is not None and len(self.entity.credentials_schema) > 0:
types.append(CredentialType.API_KEY.value)
if self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) > 0:
types.append(CredentialType.OAUTH2.value)
return types
def get_tools(self) -> list[BuiltinTool]:
"""
@@ -123,7 +166,11 @@ class BuiltinToolProviderController(ToolProviderController):
:return: whether the provider needs credentials
"""
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
return (
self.entity.credentials_schema is not None
and len(self.entity.credentials_schema) != 0
or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0)
)
@property
def provider_type(self) -> ToolProviderType:

View File

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, field_validator
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.tool_entities import CredentialType, ToolProviderType
class ToolApiEntity(BaseModel):
@@ -87,3 +87,22 @@ class ToolProviderApiEntity(BaseModel):
def optional_field(self, key: str, value: Any) -> dict:
"""Return dict with key-value if value is truthy, empty dict otherwise."""
return {key: value} if value else {}
class ToolProviderCredentialApiEntity(BaseModel):
id: str = Field(description="The unique id of the credential")
name: str = Field(description="The name of the credential")
provider: str = Field(description="The provider of the credential")
credential_type: CredentialType = Field(description="The type of the credential")
is_default: bool = Field(
default=False, description="Whether the credential is the default credential for the provider in the workspace"
)
credentials: dict = Field(description="The credentials of the provider")
class ToolProviderCredentialInfoApiEntity(BaseModel):
supported_credential_types: list[str] = Field(description="The supported credential types of the provider")
is_oauth_custom_client_enabled: bool = Field(
default=False, description="Whether the OAuth custom client is enabled for the provider"
)
credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider")

View File

@@ -370,10 +370,18 @@ class ToolEntity(BaseModel):
return v or []
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
credentials_schema: list[ProviderConfig] = Field(
default_factory=list, description="The schema of the OAuth credentials"
)
class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
plugin_id: Optional[str] = None
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
oauth_schema: Optional[OAuthSchema] = None
class ToolProviderEntityWithPlugin(ToolProviderEntity):
@@ -453,6 +461,7 @@ class ToolSelector(BaseModel):
options: Optional[list[PluginParameterOption]] = None
provider_id: str = Field(..., description="The id of the provider")
credential_id: Optional[str] = Field(default=None, description="The id of the credential")
tool_name: str = Field(..., description="The name of the tool")
tool_description: str = Field(..., description="The description of the tool")
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
@@ -460,3 +469,36 @@ class ToolSelector(BaseModel):
def to_plugin_parameter(self) -> dict[str, Any]:
return self.model_dump()
class CredentialType(enum.StrEnum):
API_KEY = "api-key"
OAUTH2 = "oauth2"
def get_name(self):
if self == CredentialType.API_KEY:
return "API KEY"
elif self == CredentialType.OAUTH2:
return "AUTH"
else:
return self.value.replace("-", " ").upper()
def is_editable(self):
return self == CredentialType.API_KEY
def is_validate_allowed(self):
return self == CredentialType.API_KEY
@classmethod
def values(cls):
return [item.value for item in cls]
@classmethod
def of(cls, credential_type: str) -> "CredentialType":
type_name = credential_type.lower()
if type_name == "api-key":
return cls.API_KEY
elif type_name == "oauth2":
return cls.OAUTH2
else:
raise ValueError(f"Invalid credential type: {credential_type}")

View File

@@ -44,6 +44,7 @@ class PluginTool(Tool):
tool_provider=self.entity.identity.provider,
tool_name=self.entity.identity.name,
credentials=self.runtime.credentials,
credential_type=self.runtime.credential_type,
tool_parameters=tool_parameters,
conversation_id=conversation_id,
app_id=app_id,

View File

@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from yarl import URL
import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
@@ -17,6 +18,7 @@ from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.workflow.entities.variable_pool import VariablePool
from services.tools.mcp_tools_mange_service import MCPToolManageService
@@ -24,7 +26,6 @@ from services.tools.mcp_tools_mange_service import MCPToolManageService
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -41,16 +42,17 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
CredentialType,
ToolInvokeFrom,
ToolParameter,
ToolProviderType,
)
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
from core.tools.errors import ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
ProviderConfigEncrypter,
ToolParameterConfigurationManager,
)
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
@@ -68,8 +70,11 @@ class ToolManager:
@classmethod
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
"""
get the hardcoded provider
"""
if len(cls._hardcoded_providers) == 0:
# init the builtin providers
cls.load_hardcoded_providers_cache()
@@ -113,7 +118,12 @@ class ToolManager:
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(Lock())
plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers:
return plugin_tool_providers[provider]
with contexts.plugin_tool_providers_lock.get():
# double check
plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers:
return plugin_tool_providers[provider]
@@ -131,25 +141,7 @@ class ToolManager:
)
plugin_tool_providers[provider] = controller
return controller
@classmethod
def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
"""
get the builtin tool
:param provider: the name of the provider
:param tool_name: the name of the tool
:param tenant_id: the id of the tenant
:return: the provider, the tool
"""
provider_controller = cls.get_builtin_provider(provider, tenant_id)
tool = provider_controller.get_tool(tool_name)
if tool is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
return tool
return controller
@classmethod
def get_tool_runtime(
@@ -160,6 +152,7 @@ class ToolManager:
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
credential_id: Optional[str] = None,
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
"""
get the tool runtime
@@ -170,6 +163,7 @@ class ToolManager:
:param tenant_id: the tenant id
:param invoke_from: invoke from
:param tool_invoke_from: the tool invoke from
:param credential_id: the credential id
:return: the tool
"""
@@ -193,49 +187,70 @@ class ToolManager:
)
),
)
builtin_provider = None
if isinstance(provider_controller, PluginToolProviderController):
provider_id_entity = ToolProviderID(provider_id)
# get credentials
builtin_provider: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.first()
)
# get specific credentials
if is_valid_uuid(credential_id):
try:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
except Exception as e:
builtin_provider = None
logger.info(f"Error getting builtin provider {credential_id}:{e}", exc_info=True)
# if the provider has been deleted, raise an error
if builtin_provider is None:
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
# fallback to the default provider
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
# use the default provider
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
else:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
# decrypt the credentials
credentials = builtin_provider.credentials
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
],
cache=ToolProviderCredentialsCache(
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
),
)
decrypted_credentials = tool_configuration.decrypt(credentials)
return cast(
BuiltinTool,
builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=decrypted_credentials,
credentials=encrypter.decrypt(builtin_provider.credentials),
credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
@@ -245,22 +260,16 @@ class ToolManager:
elif provider_type == ToolProviderType.API:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
# decrypt the credentials
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
provider_type=api_provider.provider_type.value,
provider_identity=api_provider.entity.identity.name,
controller=api_provider,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
return cast(
ApiTool,
api_provider.get_tool(tool_name).fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=decrypted_credentials,
credentials=encrypter.decrypt(credentials),
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
@@ -320,6 +329,7 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.AGENT,
credential_id=agent_tool.credential_id,
)
runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters()
@@ -362,6 +372,7 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
credential_id=workflow_tool.credential_id,
)
parameters = tool_runtime.get_merged_runtime_parameters()
@@ -391,6 +402,7 @@ class ToolManager:
provider: str,
tool_name: str,
tool_parameters: dict[str, Any],
credential_id: Optional[str] = None,
) -> Tool:
"""
get tool runtime from plugin
@@ -402,6 +414,7 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.PLUGIN,
credential_id=credential_id,
)
runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters()
@@ -551,6 +564,22 @@ class ToolManager:
return cls._builtin_tools_labels[tool_name]
@classmethod
def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]:
"""
list all the builtin providers
"""
# according to multi credentials, select the one with is_default=True first, then created_at oldest
# for compatibility with old version
sql = """
SELECT DISTINCT ON (tenant_id, provider) id
FROM tool_builtin_providers
WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
@classmethod
def list_providers_from_api(
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
@@ -565,21 +594,13 @@ class ToolManager:
with db.session.no_autoflush:
if "builtin" in filters:
# get builtin providers
builtin_providers = cls.list_builtin_providers(tenant_id)
# get db builtin providers
db_builtin_providers: list[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
)
# rewrite db_builtin_providers
for db_provider in db_builtin_providers:
tool_provider_id = str(ToolProviderID(db_provider.provider))
db_provider.provider = tool_provider_id
def find_db_builtin_provider(provider):
return next((x for x in db_builtin_providers if x.provider == provider), None)
# key: provider name, value: provider
db_builtin_providers = {
str(ToolProviderID(provider.provider)): provider
for provider in cls.list_default_builtin_providers(tenant_id)
}
# append builtin providers
for provider in builtin_providers:
@@ -591,10 +612,9 @@ class ToolManager:
name_func=lambda x: x.identity.name,
):
continue
user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider,
db_provider=find_db_builtin_provider(provider.entity.identity.name),
db_provider=db_builtin_providers.get(provider.entity.identity.name),
decrypt_credentials=False,
)
@@ -604,7 +624,6 @@ class ToolManager:
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
# get db api providers
if "api" in filters:
db_api_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
@@ -764,15 +783,12 @@ class ToolManager:
auth_type,
)
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name,
controller=controller,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
try:
icon = json.loads(provider_obj.icon)

View File

@@ -1,12 +1,8 @@
from copy import deepcopy
from typing import Any
from pydantic import BaseModel
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
@@ -14,110 +10,6 @@ from core.tools.entities.tool_entities import (
)
class ProviderConfigEncrypter(BaseModel):
tenant_id: str
config: list[BasicProviderConfig]
provider_type: str
provider_identity: str
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, str]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
if use_cache:
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cached_credentials = cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
try:
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception:
pass
if use_cache:
cache.set(data)
return data
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cache.delete()
class ToolParameterConfigurationManager:
"""
Tool parameter configuration manager

View File

@@ -0,0 +1,142 @@
from copy import deepcopy
from typing import Any, Optional, Protocol
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.tools.__base.tool_provider import ToolProviderController
class ProviderConfigCache(Protocol):
"""
Interface for provider configuration cache operations
"""
def get(self) -> Optional[dict]:
"""Get cached provider configuration"""
...
def set(self, config: dict[str, Any]) -> None:
"""Cache provider configuration"""
...
def delete(self) -> None:
"""Delete cached provider configuration"""
...
class ProviderConfigEncrypter:
tenant_id: str
config: list[BasicProviderConfig]
provider_config_cache: ProviderConfigCache
def __init__(
self,
tenant_id: str,
config: list[BasicProviderConfig],
provider_config_cache: ProviderConfigCache,
):
self.tenant_id = tenant_id
self.config = config
self.provider_config_cache = provider_config_cache
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
cached_credentials = self.provider_config_cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
try:
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception:
pass
self.provider_config_cache.set(data)
return data
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
cache = SingletonProviderCredentialsCache(
tenant_id=tenant_id,
provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name,
)
encrypt = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_config_cache=cache,
)
return encrypt, cache

View File

@@ -0,0 +1,187 @@
import base64
import hashlib
import logging
from collections.abc import Mapping
from typing import Any, Optional
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from pydantic import TypeAdapter
from configs import dify_config
logger = logging.getLogger(__name__)
class OAuthEncryptionError(Exception):
"""OAuth encryption/decryption specific error"""
pass
class SystemOAuthEncrypter:
"""
A simple OAuth parameters encrypter using AES-CBC encryption.
This class provides methods to encrypt and decrypt OAuth parameters
using AES-CBC mode with a key derived from the application's SECRET_KEY.
"""
def __init__(self, secret_key: Optional[str] = None):
"""
Initialize the OAuth encrypter.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Raises:
ValueError: If SECRET_KEY is not configured or empty
"""
secret_key = secret_key or dify_config.SECRET_KEY or ""
# Generate a fixed 256-bit key using SHA-256
self.key = hashlib.sha256(secret_key.encode()).digest()
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters.
Args:
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
Returns:
Base64-encoded encrypted string
Raises:
OAuthEncryptionError: If encryption fails
ValueError: If oauth_params is invalid
"""
try:
# Generate random IV (16 bytes)
iv = get_random_bytes(16)
# Create AES cipher (CBC mode)
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Encrypt data
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
# Combine IV and encrypted data
combined = iv + encrypted_data
# Return base64 encoded string
return base64.b64encode(combined).decode()
except Exception as e:
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Raises:
OAuthEncryptionError: If decryption fails
ValueError: If encrypted_data is invalid
"""
if not isinstance(encrypted_data, str):
raise ValueError("encrypted_data must be a string")
if not encrypted_data:
raise ValueError("encrypted_data cannot be empty")
try:
# Base64 decode
combined = base64.b64decode(encrypted_data)
# Check minimum length (IV + at least one AES block)
if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data
raise ValueError("Invalid encrypted data format")
# Separate IV and encrypted data
iv = combined[:16]
encrypted_data_bytes = combined[16:]
# Create AES cipher
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Decrypt data
decrypted_data = cipher.decrypt(encrypted_data_bytes)
unpadded_data = unpad(decrypted_data, AES.block_size)
# Parse JSON
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
if not isinstance(oauth_params, dict):
raise ValueError("Decrypted data is not a valid dictionary")
return oauth_params
except Exception as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
# Factory function for creating encrypter instances
def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter:
"""
Create an OAuth encrypter instance.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Returns:
SystemOAuthEncrypter instance
"""
return SystemOAuthEncrypter(secret_key=secret_key)
# Global encrypter instance (for backward compatibility)
_oauth_encrypter: Optional[SystemOAuthEncrypter] = None
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
"""
Get the global OAuth encrypter instance.
Returns:
SystemOAuthEncrypter instance
"""
global _oauth_encrypter
if _oauth_encrypter is None:
_oauth_encrypter = SystemOAuthEncrypter()
return _oauth_encrypter
# Convenience functions for backward compatibility
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters using the global encrypter.
Args:
oauth_params: OAuth parameters dictionary
Returns:
Base64-encoded encrypted string
"""
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters using the global encrypter.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
"""
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)

View File

@@ -1,7 +1,9 @@
import uuid
def is_valid_uuid(uuid_str: str) -> bool:
def is_valid_uuid(uuid_str: str | None) -> bool:
if uuid_str is None or len(uuid_str) == 0:
return False
try:
uuid.UUID(uuid_str)
return True

View File

@@ -4,6 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from packaging.version import Version
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -13,10 +14,16 @@ from core.agent.strategy.plugin import PluginAgentStrategy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.plugin.entities.request import InvokeCredentials
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
from core.tools.entities.tool_entities import (
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.tool_manager import ToolManager
from core.variables.segments import StringSegment
from core.workflow.entities.node_entities import NodeRunResult
@@ -84,6 +91,7 @@ class AgentNode(ToolNode):
for_log=True,
strategy=strategy,
)
credentials = self._generate_credentials(parameters=parameters)
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
@@ -94,6 +102,7 @@ class AgentNode(ToolNode):
user_id=self.user_id,
app_id=self.app_id,
conversation_id=conversation_id.text if conversation_id else None,
credentials=credentials,
)
except Exception as e:
yield RunCompletedEvent(
@@ -246,6 +255,7 @@ class AgentNode(ToolNode):
tool_name=tool.get("tool_name", ""),
tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
credential_id=tool.get("credential_id", None),
)
extra = tool.get("extra", {})
@@ -276,6 +286,7 @@ class AgentNode(ToolNode):
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
"credential_id": tool.get("credential_id", None),
"provider_type": provider_type.value,
}
)
@@ -305,6 +316,27 @@ class AgentNode(ToolNode):
return result
def _generate_credentials(
self,
parameters: dict[str, Any],
) -> InvokeCredentials:
"""
Generate credentials based on the given agent parameters.
"""
credentials = InvokeCredentials()
# generate credentials for tools selector
credentials.tool_credentials = {}
for tool in parameters.get("tools", []):
if tool.get("credential_id"):
try:
identity = ToolIdentity.model_validate(tool.get("identity", {}))
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
except ValidationError:
continue
return credentials
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@@ -14,6 +14,7 @@ class ToolEntity(BaseModel):
tool_name: str
tool_label: str # redundancy
tool_configurations: dict[str, Any]
credential_id: str | None = None
plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before")

View File

@@ -20,6 +20,7 @@ def handle(sender, **kwargs):
provider_id=tool_entity.provider_id,
tool_name=tool_entity.tool_name,
tenant_id=app.tenant_id,
credential_id=tool_entity.credential_id,
)
manager = ToolParameterConfigurationManager(
tenant_id=app.tenant_id,

View File

@@ -18,6 +18,7 @@ def init_app(app: DifyApp):
reset_email,
reset_encrypt_key_pair,
reset_password,
setup_system_tool_oauth_client,
upgrade_db,
vdb_migrate,
)
@@ -40,6 +41,7 @@ def init_app(app: DifyApp):
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
remove_orphaned_files_on_storage,
setup_system_tool_oauth_client,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@@ -0,0 +1,41 @@
"""empty message
Revision ID: 16081485540c
Revises: d28f2004b072
Create Date: 2025-05-15 16:35:39.113777
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '16081485540c'
down_revision = '2adcbe1f5dfb'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tenant_plugin_auto_upgrade_strategies',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('strategy_setting', sa.String(length=16), server_default='fix_only', nullable=False),
sa.Column('upgrade_time_of_day', sa.Integer(), nullable=False),
sa.Column('upgrade_mode', sa.String(length=16), server_default='exclude', nullable=False),
sa.Column('exclude_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
sa.Column('include_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tenant_plugin_auto_upgrade_strategy_pkey'),
sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin_auto_upgrade_strategy')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('tenant_plugin_auto_upgrade_strategies')
# ### end Alembic commands ###

View File

@@ -12,7 +12,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '4474872b0ee6'
down_revision = '2adcbe1f5dfb'
down_revision = '16081485540c'
branch_labels = None
depends_on = None

View File

@@ -0,0 +1,62 @@
"""tool oauth
Revision ID: 71f5020c6470
Revises: 4474872b0ee6
Create Date: 2025-06-24 17:05:43.118647
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '71f5020c6470'
down_revision = '1c9ba48be8e4'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tool_oauth_system_clients',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('plugin_id', sa.String(length=512), nullable=False),
sa.Column('provider', sa.String(length=255), nullable=False),
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'),
sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx')
)
op.create_table('tool_oauth_tenant_clients',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('plugin_id', sa.String(length=512), nullable=False),
sa.Column('provider', sa.String(length=255), nullable=False),
sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client')
)
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'::character varying"), nullable=False))
batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'::character varying"), nullable=False))
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name'])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider'])
batch_op.drop_column('credential_type')
batch_op.drop_column('is_default')
batch_op.drop_column('name')
op.drop_table('tool_oauth_tenant_clients')
op.drop_table('tool_oauth_system_clients')
# ### end Alembic commands ###

View File

@@ -21,6 +21,43 @@ from .model import Account, App, Tenant
from .types import StringUUID
# system level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthSystemClient(Base):
__tablename__ = "tool_oauth_system_clients"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
# tenant level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthTenantClient(Base):
__tablename__ = "tool_oauth_tenant_clients"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
@property
def oauth_params(self) -> dict:
return cast(dict, json.loads(self.encrypted_oauth_params or "{}"))
class BuiltinToolProvider(Base):
"""
This table stores the tool provider information for built-in tools for each tenant.
@@ -29,12 +66,14 @@ class BuiltinToolProvider(Base):
__tablename__ = "tool_builtin_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
# one tenant can only have one tool provider with the same name
db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"),
db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
)
# id of the tool provider
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(
db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying")
)
# id of the tenant
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
# who created this tool provider
@@ -49,6 +88,11 @@ class BuiltinToolProvider(Base):
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
# credential type, e.g., "api-key", "oauth2"
credential_type: Mapped[str] = mapped_column(
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
)
@property
def credentials(self) -> dict:
@@ -68,7 +112,7 @@ class ApiToolProvider(Base):
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# name of the api provider
name = db.Column(db.String(255), nullable=False)
name = db.Column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying"))
# icon
icon = db.Column(db.String(255), nullable=False)
# original schema
@@ -281,18 +325,19 @@ class MCPToolProvider(Base):
@property
def decrypted_credentials(self) -> dict:
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_provider_encrypter
provider_controller = MCPToolProviderController._from_db(self)
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.provider_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
cache=NoOpProviderCredentialCache(),
)
return tool_configuration.decrypt(self.credentials, use_cache=False)
return encrypter.decrypt(self.credentials) # type: ignore
class ToolModelInvoke(Base):

View File

@@ -575,13 +575,26 @@ class AppDslService:
raise ValueError("Missing draft workflow configuration, please check.")
workflow_dict = workflow.to_dict(include_secret=include_secret)
# TODO: refactor: we need a better way to filter workspace related data from nodes
for node in workflow_dict.get("graph", {}).get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
node_data = node.get("data", {})
if not node_data:
continue
data_type = node_data.get("type", "")
if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node_data.get("dataset_ids", [])
node_data["dataset_ids"] = [
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
for dataset_id in dataset_ids
]
# filter credential id from tool node
if not include_secret and data_type == NodeType.TOOL.value:
node_data.pop("credential_id", None)
# filter credential id from agent node
if not include_secret and data_type == NodeType.AGENT.value:
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
tool.pop("credential_id", None)
export_data["workflow"] = workflow_dict
dependencies = cls._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [
@@ -602,7 +615,15 @@ class AppDslService:
if not app_model_config:
raise ValueError("Missing app configuration, please check.")
export_data["model_config"] = app_model_config.to_dict()
model_config = app_model_config.to_dict()
# TODO: refactor: we need a better way to filter workspace related data from model config
# filter credential id from model config
for tool in model_config.get("agent_mode", {}).get("tools", []):
tool.pop("credential_id", None)
export_data["model_config"] = model_config
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())

View File

@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.impl.dynamic_select import DynamicSelectClient
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_tool_provider_encrypter
from extensions.ext_database import db
from models.tools import BuiltinToolProvider
@@ -38,11 +38,9 @@ class PluginParameterService:
case "tool":
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
# check if credentials are required
@@ -63,7 +61,7 @@ class PluginParameterService:
if db_record is None:
raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
credentials = tool_configuration.decrypt(db_record.credentials)
credentials = encrypter.decrypt(db_record.credentials)
case _:
raise ValueError(f"Invalid provider type: {provider_type}")

View File

@@ -196,6 +196,17 @@ class PluginService:
manager = PluginInstaller()
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
@staticmethod
def is_plugin_verified(tenant_id: str, plugin_unique_identifier: str) -> bool:
"""
Check if the plugin is verified
"""
manager = PluginInstaller()
try:
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier).verified
except Exception:
return False
@staticmethod
def fetch_install_tasks(tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]:
"""

View File

@@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import (
)
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_tool_provider_encrypter
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
from models.tools import ApiToolProvider
@@ -164,15 +164,11 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# encrypt credentials
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
encrypted_credentials = tool_configuration.encrypt(credentials)
db_provider.credentials_str = json.dumps(encrypted_credentials)
db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials))
db.session.add(db_provider)
db.session.commit()
@@ -297,28 +293,26 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# get original credentials if exists
tool_configuration = ProviderConfigEncrypter(
encrypter, cache = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
original_credentials = encrypter.decrypt(provider.credentials)
masked_credentials = encrypter.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
credentials = tool_configuration.encrypt(credentials)
credentials = encrypter.encrypt(credentials)
provider.credentials_str = json.dumps(credentials)
db.session.add(provider)
db.session.commit()
# delete cache
tool_configuration.delete_tool_credentials_cache()
cache.delete()
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
@@ -416,15 +410,13 @@ class ApiToolManageService:
# decrypt credentials
if db_provider.id:
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
decrypted_credentials = encrypter.decrypt(credentials)
# check if the credential has changed, save the original credential
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials)
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = decrypted_credentials[name]
@@ -446,7 +438,7 @@ class ApiToolManageService:
return {"result": result or "empty response"}
@staticmethod
def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
"""
list api tools
"""
@@ -474,7 +466,7 @@ class ApiToolManageService:
for tool in tools or []:
user_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
tenant_id=tenant_id, tool=tool, labels=labels
)
)

View File

@@ -1,28 +1,84 @@
import json
import logging
import re
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.entities.api_entities import (
ToolApiEntity,
ToolProviderApiEntity,
ToolProviderCredentialApiEntity,
ToolProviderCredentialInfoApiEntity,
)
from core.tools.entities.tool_entities import CredentialType
from core.tools.errors import ToolProviderNotFoundError
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_provider_encrypter
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from extensions.ext_database import db
from models.tools import BuiltinToolProvider
from extensions.ext_redis import redis_client
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
from services.plugin.plugin_service import PluginService
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class BuiltinToolManageService:
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
@staticmethod
def delete_custom_oauth_client_params(tenant_id: str, provider: str):
"""
delete custom oauth client params
"""
tool_provider = ToolProviderID(provider)
with Session(db.engine) as session:
session.query(ToolOAuthTenantClient).filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
).delete()
session.commit()
return {"result": "success"}
@staticmethod
def get_builtin_tool_provider_oauth_client_schema(tenant_id: str, provider_name: str):
"""
get builtin tool provider oauth client schema
"""
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
tenant_id, provider.plugin_unique_identifier
)
is_oauth_custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled(
tenant_id, provider_name
)
is_system_oauth_params_exists = verified and BuiltinToolManageService.is_oauth_system_client_exists(
provider_name
)
result = {
"schema": provider.get_oauth_client_schema(),
"is_oauth_custom_client_enabled": is_oauth_custom_client_enabled,
"is_system_oauth_params_exists": is_system_oauth_params_exists,
"client_params": BuiltinToolManageService.get_custom_oauth_client_params(tenant_id, provider_name),
"redirect_uri": f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback",
}
return result
@staticmethod
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
"""
@@ -36,27 +92,11 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tools = provider_controller.get_tools()
tool_provider_configurations = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt(credentials)
result: list[ToolApiEntity] = []
for tool in tools or []:
result.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool,
credentials=credentials,
tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
@@ -65,25 +105,15 @@ class BuiltinToolManageService:
return result
@staticmethod
def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
def get_builtin_tool_provider_info(tenant_id: str, provider: str):
"""
get builtin tool provider info
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tool_provider_configurations = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt(credentials)
builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
if builtin_provider is None:
raise ValueError(f"you have not added provider {provider}")
entity = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
@@ -92,128 +122,407 @@ class BuiltinToolManageService:
)
entity.original_credentials = {}
return entity
@staticmethod
def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str):
"""
list builtin provider credentials schema
:param credential_type: credential type
:param provider_name: the name of the provider
:param tenant_id: the id of the tenant
:return: the list of tool providers
"""
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
return jsonable_encoder(provider.get_credentials_schema())
return provider.get_credentials_schema_by_type(credential_type)
@staticmethod
def update_builtin_tool_provider(
session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
user_id: str,
tenant_id: str,
provider: str,
credential_id: str,
credentials: dict | None = None,
name: str | None = None,
):
"""
update builtin tool provider
"""
# get if the provider exists
provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
try:
# get provider
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
with Session(db.engine) as session:
# get if the provider exists
db_provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
# get original credentials if exists
if provider is not None:
original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
# validate credentials
provider_controller.validate_credentials(user_id, credentials)
# encrypt credentials
credentials = tool_configuration.encrypt(credentials)
except (
PluginDaemonClientSideError,
ToolProviderNotFoundError,
ToolNotFoundError,
ToolProviderCredentialValidationError,
) as e:
raise ValueError(str(e))
try:
if CredentialType.of(db_provider.credential_type).is_editable() and credentials:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider} does not need credentials")
if provider is None:
# create provider
provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider_name,
encrypted_credentials=json.dumps(credentials),
)
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, db_provider, provider, provider_controller
)
db.session.add(provider)
else:
provider.encrypted_credentials = json.dumps(credentials)
original_credentials = encrypter.decrypt(db_provider.credentials)
new_credentials: dict = {
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
# delete cache
tool_configuration.delete_tool_credentials_cache()
if CredentialType.of(db_provider.credential_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, new_credentials)
db.session.commit()
# encrypt credentials
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
cache.delete()
# update name if provided
if name and name != db_provider.name:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
):
raise ValueError(f"the credential name '{name}' is already used")
db_provider.name = name
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@staticmethod
def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
def add_builtin_tool_provider(
user_id: str,
api_type: CredentialType,
tenant_id: str,
provider: str,
credentials: dict,
name: str | None = None,
):
"""
add builtin tool provider
"""
try:
with Session(db.engine) as session:
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20):
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider} does not need credentials")
provider_count = (
session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
)
# check if the provider count is reached the limit
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
raise ValueError(f"you have reached the maximum number of providers for {provider}")
# validate credentials if allowed
if CredentialType.of(api_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, credentials)
# generate name if not provided
if name is None or name == "":
name = BuiltinToolManageService.generate_builtin_tool_provider_name(
session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type
)
else:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
):
raise ValueError(f"the credential name '{name}' is already used")
# create encrypter
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(api_type)
],
cache=NoOpProviderCredentialCache(),
)
db_provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credential_type=api_type.value,
name=name,
)
session.add(db_provider)
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@staticmethod
def create_tool_encrypter(
tenant_id: str,
db_provider: BuiltinToolProvider,
provider: str,
provider_controller: BuiltinToolProviderController,
):
encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type)
],
cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id),
)
return encrypter, cache
@staticmethod
def generate_builtin_tool_provider_name(
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
) -> str:
try:
db_providers = (
session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider,
credential_type=credential_type.value,
)
.order_by(BuiltinToolProvider.created_at.desc())
.all()
)
# Get the default name pattern
default_pattern = f"{credential_type.get_name()}"
# Find all names that match the default pattern: "{default_pattern} {number}"
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
numbers = []
for db_provider in db_providers:
if db_provider.name:
match = re.match(pattern, db_provider.name.strip())
if match:
numbers.append(int(match.group(1)))
# If no default pattern names found, start with 1
if not numbers:
return f"{default_pattern} 1"
# Find the next number
max_number = max(numbers)
return f"{default_pattern} {max_number + 1}"
except Exception as e:
logger.warning(f"Error generating next provider name for {provider}: {str(e)}")
# fallback
return f"{credential_type.get_name()} 1"
@staticmethod
def get_builtin_tool_provider_credentials(
tenant_id: str, provider_name: str
) -> list[ToolProviderCredentialApiEntity]:
"""
get builtin tool provider credentials
"""
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
with db.session.no_autoflush:
providers = (
db.session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider_name)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.all()
)
if provider_obj is None:
return {}
if len(providers) == 0:
return []
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
credentials = tool_configuration.decrypt(provider_obj.credentials)
credentials = tool_configuration.mask_tool_credentials(credentials)
return credentials
default_provider = providers[0]
default_provider.is_default = True
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
credentials: list[ToolProviderCredentialApiEntity] = []
encrypters = {}
for provider in providers:
credential_type = provider.credential_type
if credential_type not in encrypters:
encrypters[credential_type] = BuiltinToolManageService.create_tool_encrypter(
tenant_id, provider, provider.provider, provider_controller
)[0]
encrypter = encrypters[credential_type]
decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
provider=provider,
credentials=decrypt_credential,
)
credentials.append(credential_entity)
return credentials
@staticmethod
def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
def get_builtin_tool_provider_credential_info(tenant_id: str, provider: str) -> ToolProviderCredentialInfoApiEntity:
"""
get builtin tool provider credential info
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
supported_credential_types = provider_controller.get_supported_credential_types()
credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider)
credential_info = ToolProviderCredentialInfoApiEntity(
supported_credential_types=supported_credential_types,
is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
credentials=credentials,
)
return credential_info
@staticmethod
def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str):
"""
delete tool provider
"""
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
with Session(db.engine) as session:
db_provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
if provider_obj is None:
raise ValueError(f"you have not added provider {provider_name}")
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
db.session.delete(provider_obj)
db.session.commit()
session.delete(db_provider)
session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
tool_configuration.delete_tool_credentials_cache()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
_, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, db_provider, provider, provider_controller
)
cache.delete()
return {"result": "success"}
@staticmethod
def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
"""
set default provider
"""
with Session(db.engine) as session:
# get provider
target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first()
if target_provider is None:
raise ValueError("provider not found")
# clear default provider
session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
).update({"is_default": False})
# set new default provider
target_provider.is_default = True
session.commit()
return {"result": "success"}
@staticmethod
def is_oauth_system_client_exists(provider_name: str) -> bool:
"""
check if oauth system client exists
"""
tool_provider = ToolProviderID(provider_name)
with Session(db.engine).no_autoflush as session:
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
.first()
)
return system_client is not None
@staticmethod
def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool:
"""
check if oauth custom client is enabled
"""
tool_provider = ToolProviderID(provider)
with Session(db.engine).no_autoflush as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
enabled=True,
)
.first()
)
return user_client is not None and user_client.enabled
@staticmethod
def get_oauth_client(tenant_id: str, provider: str) -> Mapping[str, Any] | None:
"""
get builtin tool provider
"""
tool_provider = ToolProviderID(provider)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
with Session(db.engine).no_autoflush as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
enabled=True,
)
.first()
)
oauth_params: Mapping[str, Any] | None = None
if user_client:
oauth_params = encrypter.decrypt(user_client.oauth_params)
return oauth_params
# only verified provider can use custom oauth client
is_verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
tenant_id, provider.plugin_unique_identifier
)
if not is_verified:
return oauth_params
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
.first()
)
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")
return oauth_params
@staticmethod
def get_builtin_tool_provider_icon(provider: str):
"""
@@ -234,9 +543,7 @@ class BuiltinToolManageService:
with db.session.no_autoflush:
# get all user added providers
db_providers: list[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
)
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
# rewrite db_providers
for db_provider in db_providers:
@@ -275,7 +582,6 @@ class BuiltinToolManageService:
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
credentials=user_builtin_provider.original_credentials,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
@@ -287,43 +593,153 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
try:
full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name)
provider_name = provider_id_entity.provider_name
if provider_id_entity.organization != "langgenius":
provider_obj = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == full_provider_name,
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
"""
This method is used to fetch the builtin provider from the database
1.if the default provider exists, return the default provider
2.if the default provider does not exist, return the oldest provider
"""
with Session(db.engine) as session:
try:
full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name)
provider_name = provider_id_entity.provider_name
if provider_id_entity.organization != "langgenius":
provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == full_provider_name,
)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
.first()
)
else:
provider_obj = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name)
| (BuiltinToolProvider.provider == full_provider_name),
else:
provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name)
| (BuiltinToolProvider.provider == full_provider_name),
)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
if provider is None:
return None
provider.provider = ToolProviderID(provider.provider).to_string()
return provider
except Exception:
# it's an old provider without organization
return (
session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
if provider_obj is None:
return None
@staticmethod
def save_custom_oauth_client_params(
tenant_id: str,
provider: str,
client_params: Optional[dict] = None,
enable_oauth_custom_client: Optional[bool] = None,
):
"""
setup oauth custom client
"""
if client_params is None and enable_oauth_custom_client is None:
return {"result": "success"}
provider_obj.provider = ToolProviderID(provider_obj.provider).to_string()
return provider_obj
except Exception:
# it's an old provider without organization
return (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name),
tool_provider = ToolProviderID(provider)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller:
raise ToolProviderNotFoundError(f"Provider {provider} not found")
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
with Session(db.engine) as session:
custom_client_params = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
.first()
)
# if the record does not exist, create a basic record
if custom_client_params is None:
custom_client_params = ToolOAuthTenantClient(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
session.add(custom_client_params)
if client_params is not None:
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
original_params = encrypter.decrypt(custom_client_params.oauth_params)
new_params: dict = {
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
for key, value in client_params.items()
}
custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))
if enable_oauth_custom_client is not None:
custom_client_params.enabled = enable_oauth_custom_client
session.commit()
return {"result": "success"}
@staticmethod
def get_custom_oauth_client_params(tenant_id: str, provider: str):
"""
get custom oauth client params
"""
with Session(db.engine) as session:
tool_provider = ToolProviderID(provider)
custom_oauth_client_params: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
.first()
)
if custom_oauth_client_params is None:
return {}
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller:
raise ToolProviderNotFoundError(f"Provider {provider} not found")
if not isinstance(provider_controller, BuiltinToolProviderController):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))

View File

@@ -7,13 +7,14 @@ from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import ProviderConfigEncrypter
from extensions.ext_database import db
from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
@@ -69,6 +70,7 @@ class MCPToolManageService:
MCPToolProvider.server_url_hash == server_url_hash,
MCPToolProvider.server_identifier == server_identifier,
),
MCPToolProvider.tenant_id == tenant_id,
)
.first()
)
@@ -197,8 +199,7 @@ class MCPToolManageService:
tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.provider_id,
provider_config_cache=NoOpProviderCredentialCache(),
)
credentials = tool_configuration.encrypt(credentials)
mcp_provider.updated_at = datetime.now()

View File

@@ -5,21 +5,23 @@ from typing import Any, Optional, Union, cast
from yarl import URL
from configs import dify_config
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.mcp.types import Tool as MCPTool
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
CredentialType,
ToolParameter,
ToolProviderType,
)
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
@@ -119,7 +121,12 @@ class ToolTransformService:
result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
# get credentials schema
schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
schema = {
x.to_basic_provider_config().name: x
for x in provider_controller.get_credentials_schema_by_type(
CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY
)
}
for name, value in schema.items():
if result.masked_credentials:
@@ -136,15 +143,23 @@ class ToolTransformService:
credentials = db_provider.credentials
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_provider_encrypter(
tenant_id=db_provider.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(
CredentialType.of(db_provider.credential_type)
)
],
cache=ToolProviderCredentialsCache(
tenant_id=db_provider.tenant_id,
provider=db_provider.provider,
credential_id=db_provider.id,
),
)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt(data=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
decrypted_credentials = encrypter.decrypt(data=credentials)
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials
result.original_credentials = decrypted_credentials
@@ -287,16 +302,14 @@ class ToolTransformService:
if decrypt_credentials:
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=db_provider.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt(data=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
decrypted_credentials = encrypter.decrypt(data=credentials)
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials
@@ -306,7 +319,6 @@ class ToolTransformService:
def convert_tool_entity_to_api_entity(
tool: Union[ApiToolBundle, WorkflowTool, Tool],
tenant_id: str,
credentials: dict | None = None,
labels: list[str] | None = None,
) -> ToolApiEntity:
"""
@@ -316,7 +328,7 @@ class ToolTransformService:
# fork tool runtime
tool = tool.fork_tool_runtime(
runtime=ToolRuntime(
credentials=credentials or {},
credentials={},
tenant_id=tenant_id,
)
)
@@ -357,6 +369,19 @@ class ToolTransformService:
labels=labels or [],
)
@staticmethod
def convert_builtin_provider_to_credential_entity(
provider: BuiltinToolProvider, credentials: dict
) -> ToolProviderCredentialApiEntity:
return ToolProviderCredentialApiEntity(
id=provider.id,
name=provider.name,
provider=provider.provider,
credential_type=CredentialType.of(provider.credential_type),
is_default=provider.is_default,
credentials=credentials,
)
@staticmethod
def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
"""

View File

@@ -0,0 +1,619 @@
import base64
import hashlib
from unittest.mock import patch
import pytest
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad
from core.tools.utils.system_oauth_encryption import (
OAuthEncryptionError,
SystemOAuthEncrypter,
create_system_oauth_encrypter,
decrypt_system_oauth_params,
encrypt_system_oauth_params,
get_system_oauth_encrypter,
)
class TestSystemOAuthEncrypter:
"""Test cases for SystemOAuthEncrypter class"""
def test_init_with_secret_key(self):
"""Test initialization with provided secret key"""
secret_key = "test_secret_key"
encrypter = SystemOAuthEncrypter(secret_key=secret_key)
expected_key = hashlib.sha256(secret_key.encode()).digest()
assert encrypter.key == expected_key
def test_init_with_none_secret_key(self):
"""Test initialization with None secret key falls back to config"""
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "config_secret"
encrypter = SystemOAuthEncrypter(secret_key=None)
expected_key = hashlib.sha256(b"config_secret").digest()
assert encrypter.key == expected_key
def test_init_with_empty_secret_key(self):
"""Test initialization with empty secret key"""
encrypter = SystemOAuthEncrypter(secret_key="")
expected_key = hashlib.sha256(b"").digest()
assert encrypter.key == expected_key
def test_init_without_secret_key_uses_config(self):
"""Test initialization without secret key uses config"""
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "default_secret"
encrypter = SystemOAuthEncrypter()
expected_key = hashlib.sha256(b"default_secret").digest()
assert encrypter.key == expected_key
def test_encrypt_oauth_params_basic(self):
"""Test basic OAuth parameters encryption"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
# Should be valid base64
try:
base64.b64decode(encrypted)
except Exception:
pytest.fail("Encrypted result is not valid base64")
def test_encrypt_oauth_params_empty_dict(self):
"""Test encryption with empty dictionary"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_oauth_params_complex_data(self):
"""Test encryption with complex data structures"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {
"client_id": "test_id",
"client_secret": "test_secret",
"scopes": ["read", "write", "admin"],
"metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
"numeric_value": 42,
"boolean_value": False,
"null_value": None,
}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_oauth_params_unicode_data(self):
"""Test encryption with unicode data"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case 🚀"}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_oauth_params_large_data(self):
"""Test encryption with large data"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {
"client_id": "test_id",
"large_data": "x" * 10000, # 10KB of data
}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_oauth_params_invalid_input(self):
"""Test encryption with invalid input types"""
encrypter = SystemOAuthEncrypter("test_secret")
with pytest.raises(Exception): # noqa: B017
encrypter.encrypt_oauth_params(None) # type: ignore
with pytest.raises(Exception): # noqa: B017
encrypter.encrypt_oauth_params("not_a_dict") # type: ignore
def test_decrypt_oauth_params_basic(self):
"""Test basic OAuth parameters decryption"""
encrypter = SystemOAuthEncrypter("test_secret")
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_empty_dict(self):
"""Test decryption of empty dictionary"""
encrypter = SystemOAuthEncrypter("test_secret")
original_params = {}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_complex_data(self):
"""Test decryption with complex data structures"""
encrypter = SystemOAuthEncrypter("test_secret")
original_params = {
"client_id": "test_id",
"client_secret": "test_secret",
"scopes": ["read", "write", "admin"],
"metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
"numeric_value": 42,
"boolean_value": False,
"null_value": None,
}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_unicode_data(self):
"""Test decryption with unicode data"""
encrypter = SystemOAuthEncrypter("test_secret")
original_params = {
"client_id": "test_id",
"client_secret": "test_secret",
"description": "This is a test case 🚀",
}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_large_data(self):
"""Test decryption with large data"""
encrypter = SystemOAuthEncrypter("test_secret")
original_params = {
"client_id": "test_id",
"large_data": "x" * 10000, # 10KB of data
}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_invalid_base64(self):
"""Test decryption with invalid base64 data"""
encrypter = SystemOAuthEncrypter("test_secret")
with pytest.raises(OAuthEncryptionError):
encrypter.decrypt_oauth_params("invalid_base64!")
def test_decrypt_oauth_params_empty_string(self):
"""Test decryption with empty string"""
encrypter = SystemOAuthEncrypter("test_secret")
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params("")
assert "encrypted_data cannot be empty" in str(exc_info.value)
def test_decrypt_oauth_params_non_string_input(self):
"""Test decryption with non-string input"""
encrypter = SystemOAuthEncrypter("test_secret")
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params(123) # type: ignore
assert "encrypted_data must be a string" in str(exc_info.value)
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params(None) # type: ignore
assert "encrypted_data must be a string" in str(exc_info.value)
def test_decrypt_oauth_params_too_short_data(self):
"""Test decryption with too short encrypted data"""
encrypter = SystemOAuthEncrypter("test_secret")
# Create data that's too short (less than 32 bytes)
short_data = base64.b64encode(b"short").decode()
with pytest.raises(OAuthEncryptionError) as exc_info:
encrypter.decrypt_oauth_params(short_data)
assert "Invalid encrypted data format" in str(exc_info.value)
def test_decrypt_oauth_params_corrupted_data(self):
"""Test decryption with corrupted data"""
encrypter = SystemOAuthEncrypter("test_secret")
# Create corrupted data (valid base64 but invalid encrypted content)
corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage
with pytest.raises(OAuthEncryptionError):
encrypter.decrypt_oauth_params(corrupted_data)
def test_decrypt_oauth_params_wrong_key(self):
"""Test decryption with wrong key"""
encrypter1 = SystemOAuthEncrypter("secret1")
encrypter2 = SystemOAuthEncrypter("secret2")
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypter1.encrypt_oauth_params(original_params)
with pytest.raises(OAuthEncryptionError):
encrypter2.decrypt_oauth_params(encrypted)
def test_encryption_decryption_consistency(self):
"""Test that encryption and decryption are consistent"""
encrypter = SystemOAuthEncrypter("test_secret")
test_cases = [
{},
{"simple": "value"},
{"client_id": "id", "client_secret": "secret"},
{"complex": {"nested": {"deep": "value"}}},
{"unicode": "test 🚀"},
{"numbers": 42, "boolean": True, "null": None},
{"array": [1, 2, 3, "four", {"five": 5}]},
]
for original_params in test_cases:
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params, f"Failed for case: {original_params}"
def test_encryption_randomness(self):
"""Test that encryption produces different results for same input"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted1 = encrypter.encrypt_oauth_params(oauth_params)
encrypted2 = encrypter.encrypt_oauth_params(oauth_params)
# Should be different due to random IV
assert encrypted1 != encrypted2
# But should decrypt to same result
decrypted1 = encrypter.decrypt_oauth_params(encrypted1)
decrypted2 = encrypter.decrypt_oauth_params(encrypted2)
assert decrypted1 == decrypted2 == oauth_params
def test_different_secret_keys_produce_different_results(self):
"""Test that different secret keys produce different encrypted results"""
encrypter1 = SystemOAuthEncrypter("secret1")
encrypter2 = SystemOAuthEncrypter("secret2")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted1 = encrypter1.encrypt_oauth_params(oauth_params)
encrypted2 = encrypter2.encrypt_oauth_params(oauth_params)
# Should produce different encrypted results
assert encrypted1 != encrypted2
# But each should decrypt correctly with its own key
decrypted1 = encrypter1.decrypt_oauth_params(encrypted1)
decrypted2 = encrypter2.decrypt_oauth_params(encrypted2)
assert decrypted1 == decrypted2 == oauth_params
@patch("core.tools.utils.system_oauth_encryption.get_random_bytes")
def test_encrypt_oauth_params_crypto_error(self, mock_get_random_bytes):
"""Test encryption when crypto operation fails"""
mock_get_random_bytes.side_effect = Exception("Crypto error")
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id"}
with pytest.raises(OAuthEncryptionError) as exc_info:
encrypter.encrypt_oauth_params(oauth_params)
assert "Encryption failed" in str(exc_info.value)
@patch("core.tools.utils.system_oauth_encryption.TypeAdapter")
def test_encrypt_oauth_params_serialization_error(self, mock_type_adapter):
"""Test encryption when JSON serialization fails"""
mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error")
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id"}
with pytest.raises(OAuthEncryptionError) as exc_info:
encrypter.encrypt_oauth_params(oauth_params)
assert "Encryption failed" in str(exc_info.value)
def test_decrypt_oauth_params_invalid_json(self):
"""Test decryption with invalid JSON data"""
encrypter = SystemOAuthEncrypter("test_secret")
# Create valid encrypted data but with invalid JSON content
iv = get_random_bytes(16)
cipher = AES.new(encrypter.key, AES.MODE_CBC, iv)
invalid_json = b"invalid json content"
padded_data = pad(invalid_json, AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
combined = iv + encrypted_data
encoded = base64.b64encode(combined).decode()
with pytest.raises(OAuthEncryptionError):
encrypter.decrypt_oauth_params(encoded)
def test_key_derivation_consistency(self):
"""Test that key derivation is consistent"""
secret_key = "test_secret"
encrypter1 = SystemOAuthEncrypter(secret_key)
encrypter2 = SystemOAuthEncrypter(secret_key)
assert encrypter1.key == encrypter2.key
# Keys should be 32 bytes (256 bits)
assert len(encrypter1.key) == 32
class TestFactoryFunctions:
"""Test cases for factory functions"""
def test_create_system_oauth_encrypter_with_secret(self):
"""Test factory function with secret key"""
secret_key = "test_secret"
encrypter = create_system_oauth_encrypter(secret_key)
assert isinstance(encrypter, SystemOAuthEncrypter)
expected_key = hashlib.sha256(secret_key.encode()).digest()
assert encrypter.key == expected_key
def test_create_system_oauth_encrypter_without_secret(self):
"""Test factory function without secret key"""
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "config_secret"
encrypter = create_system_oauth_encrypter()
assert isinstance(encrypter, SystemOAuthEncrypter)
expected_key = hashlib.sha256(b"config_secret").digest()
assert encrypter.key == expected_key
def test_create_system_oauth_encrypter_with_none_secret(self):
"""Test factory function with None secret key"""
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "config_secret"
encrypter = create_system_oauth_encrypter(None)
assert isinstance(encrypter, SystemOAuthEncrypter)
expected_key = hashlib.sha256(b"config_secret").digest()
assert encrypter.key == expected_key
class TestGlobalEncrypterInstance:
"""Test cases for global encrypter instance"""
def test_get_system_oauth_encrypter_singleton(self):
"""Test that get_system_oauth_encrypter returns singleton instance"""
# Clear the global instance first
import core.tools.utils.system_oauth_encryption
core.tools.utils.system_oauth_encryption._oauth_encrypter = None
encrypter1 = get_system_oauth_encrypter()
encrypter2 = get_system_oauth_encrypter()
assert encrypter1 is encrypter2
assert isinstance(encrypter1, SystemOAuthEncrypter)
def test_get_system_oauth_encrypter_uses_config(self):
"""Test that global encrypter uses config"""
# Clear the global instance first
import core.tools.utils.system_oauth_encryption
core.tools.utils.system_oauth_encryption._oauth_encrypter = None
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "global_secret"
encrypter = get_system_oauth_encrypter()
expected_key = hashlib.sha256(b"global_secret").digest()
assert encrypter.key == expected_key
class TestConvenienceFunctions:
"""Test cases for convenience functions"""
def test_encrypt_system_oauth_params(self):
"""Test encrypt_system_oauth_params convenience function"""
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypt_system_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_decrypt_system_oauth_params(self):
"""Test decrypt_system_oauth_params convenience function"""
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypt_system_oauth_params(oauth_params)
decrypted = decrypt_system_oauth_params(encrypted)
assert decrypted == oauth_params
def test_convenience_functions_consistency(self):
"""Test that convenience functions work consistently"""
test_cases = [
{},
{"simple": "value"},
{"client_id": "id", "client_secret": "secret"},
{"complex": {"nested": {"deep": "value"}}},
{"unicode": "test 🚀"},
{"numbers": 42, "boolean": True, "null": None},
]
for original_params in test_cases:
encrypted = encrypt_system_oauth_params(original_params)
decrypted = decrypt_system_oauth_params(encrypted)
assert decrypted == original_params, f"Failed for case: {original_params}"
def test_convenience_functions_with_errors(self):
"""Test convenience functions with error conditions"""
# Test encryption with invalid input
with pytest.raises(Exception): # noqa: B017
encrypt_system_oauth_params(None) # type: ignore
# Test decryption with invalid input
with pytest.raises(ValueError):
decrypt_system_oauth_params("")
with pytest.raises(ValueError):
decrypt_system_oauth_params(None) # type: ignore
class TestErrorHandling:
"""Test cases for error handling"""
def test_oauth_encryption_error_inheritance(self):
"""Test that OAuthEncryptionError is a proper exception"""
error = OAuthEncryptionError("Test error")
assert isinstance(error, Exception)
assert str(error) == "Test error"
def test_oauth_encryption_error_with_cause(self):
"""Test OAuthEncryptionError with cause"""
original_error = ValueError("Original error")
error = OAuthEncryptionError("Wrapper error")
error.__cause__ = original_error
assert isinstance(error, Exception)
assert str(error) == "Wrapper error"
assert error.__cause__ is original_error
def test_error_messages_are_informative(self):
"""Test that error messages are informative"""
encrypter = SystemOAuthEncrypter("test_secret")
# Test empty string error
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params("")
assert "encrypted_data cannot be empty" in str(exc_info.value)
# Test non-string error
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params(123) # type: ignore
assert "encrypted_data must be a string" in str(exc_info.value)
# Test invalid format error
short_data = base64.b64encode(b"short").decode()
with pytest.raises(OAuthEncryptionError) as exc_info:
encrypter.decrypt_oauth_params(short_data)
assert "Invalid encrypted data format" in str(exc_info.value)
class TestEdgeCases:
"""Test cases for edge cases and boundary conditions"""
def test_very_long_secret_key(self):
"""Test with very long secret key"""
long_secret = "x" * 10000
encrypter = SystemOAuthEncrypter(long_secret)
# Key should still be 32 bytes due to SHA-256
assert len(encrypter.key) == 32
# Should still work normally
oauth_params = {"client_id": "test_id"}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_special_characters_in_secret_key(self):
"""Test with special characters in secret key"""
special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~test🚀"
encrypter = SystemOAuthEncrypter(special_secret)
oauth_params = {"client_id": "test_id"}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_empty_values_in_oauth_params(self):
"""Test with empty values in oauth params"""
oauth_params = {
"client_id": "",
"client_secret": "",
"empty_dict": {},
"empty_list": [],
"empty_string": "",
"zero": 0,
"false": False,
"none": None,
}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_deeply_nested_oauth_params(self):
"""Test with deeply nested oauth params"""
oauth_params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_oauth_params_with_all_json_types(self):
"""Test with all JSON-supported data types"""
oauth_params = {
"string": "test_string",
"integer": 42,
"float": 3.14159,
"boolean_true": True,
"boolean_false": False,
"null_value": None,
"empty_string": "",
"array": [1, "two", 3.0, True, False, None],
"object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True},
}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
class TestPerformance:
"""Test cases for performance considerations"""
def test_large_oauth_params(self):
"""Test with large oauth params"""
large_value = "x" * 100000 # 100KB
oauth_params = {"client_id": "test_id", "large_data": large_value}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_many_fields_oauth_params(self):
"""Test with many fields in oauth params"""
oauth_params = {f"field_{i}": f"value_{i}" for i in range(1000)}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_repeated_encryption_decryption(self):
"""Test repeated encryption and decryption operations"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
# Test multiple rounds of encryption/decryption
for i in range(100):
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params

View File

@@ -18,7 +18,6 @@ import AppIcon from '@/app/components/base/app-icon'
import Button from '@/app/components/base/button'
import Indicator from '@/app/components/header/indicator'
import Switch from '@/app/components/base/switch'
import Toast from '@/app/components/base/toast'
import ConfigContext from '@/context/debug-configuration'
import type { AgentTool } from '@/types/app'
import { type Collection, CollectionType } from '@/app/components/tools/types'
@@ -26,8 +25,6 @@ import { MAX_TOOLS_NUM } from '@/config'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
import Tooltip from '@/app/components/base/tooltip'
import { DefaultToolIcon } from '@/app/components/base/icons/src/public/other'
import ConfigCredential from '@/app/components/tools/setting/build-in/config-credentials'
import { updateBuiltInToolCredential } from '@/service/tools'
import cn from '@/utils/classnames'
import ToolPicker from '@/app/components/workflow/block-selector/tool-picker'
import type { ToolDefaultValue, ToolValue } from '@/app/components/workflow/block-selector/types'
@@ -57,13 +54,7 @@ const AgentTools: FC = () => {
const formattingChangedDispatcher = useFormattingChangedDispatcher()
const [currentTool, setCurrentTool] = useState<AgentToolWithMoreInfo>(null)
const currentCollection = useMemo(() => {
if (!currentTool) return null
const collection = collectionList.find(collection => canFindTool(collection.id, currentTool?.provider_id) && collection.type === currentTool?.provider_type)
return collection
}, [currentTool, collectionList])
const [isShowSettingTool, setIsShowSettingTool] = useState(false)
const [isShowSettingAuth, setShowSettingAuth] = useState(false)
const tools = (modelConfig?.agentConfig?.tools as AgentTool[] || []).map((item) => {
const collection = collectionList.find(
collection =>
@@ -100,17 +91,6 @@ const AgentTools: FC = () => {
formattingChangedDispatcher()
}
const handleToolAuthSetting = (value: AgentToolWithMoreInfo) => {
const newModelConfig = produce(modelConfig, (draft) => {
const tool = (draft.agentConfig.tools).find((item: any) => item.provider_id === value?.collection?.id && item.tool_name === value?.tool_name)
if (tool)
(tool as AgentTool).notAuthor = false
})
setModelConfig(newModelConfig)
setIsShowSettingTool(false)
formattingChangedDispatcher()
}
const [isDeleting, setIsDeleting] = useState<number>(-1)
const getToolValue = (tool: ToolDefaultValue) => {
return {
@@ -144,6 +124,20 @@ const AgentTools: FC = () => {
return item.provider_name
}
const handleAuthorizationItemClick = useCallback((credentialId: string) => {
const newModelConfig = produce(modelConfig, (draft) => {
const tool = (draft.agentConfig.tools).find((item: any) => item.provider_id === currentTool?.provider_id)
if (tool)
(tool as AgentTool).credential_id = credentialId
})
setCurrentTool({
...currentTool,
credential_id: credentialId,
} as any)
setModelConfig(newModelConfig)
formattingChangedDispatcher()
}, [currentTool, modelConfig, setModelConfig, formattingChangedDispatcher])
return (
<>
<Panel
@@ -299,7 +293,7 @@ const AgentTools: FC = () => {
{item.notAuthor && (
<Button variant='secondary' size='small' onClick={() => {
setCurrentTool(item)
setShowSettingAuth(true)
setIsShowSettingTool(true)
}}>
{t('tools.notAuthorized')}
<Indicator className='ml-2' color='orange' />
@@ -319,21 +313,8 @@ const AgentTools: FC = () => {
isModel={currentTool?.collection?.type === CollectionType.model}
onSave={handleToolSettingChange}
onHide={() => setIsShowSettingTool(false)}
/>
)}
{isShowSettingAuth && (
<ConfigCredential
collection={currentCollection as any}
onCancel={() => setShowSettingAuth(false)}
onSaved={async (value) => {
await updateBuiltInToolCredential((currentCollection as any).name, value)
Toast.notify({
type: 'success',
message: t('common.api.actionSuccess'),
})
handleToolAuthSetting(currentTool)
setShowSettingAuth(false)
}}
credentialId={currentTool?.credential_id}
onAuthorizationItemClick={handleAuthorizationItemClick}
/>
)}
</>

View File

@@ -14,7 +14,6 @@ import Icon from '@/app/components/plugins/card/base/card-icon'
import OrgInfo from '@/app/components/plugins/card/base/org-info'
import Description from '@/app/components/plugins/card/base/description'
import TabSlider from '@/app/components/base/tab-slider-plain'
import Button from '@/app/components/base/button'
import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form'
import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
@@ -25,6 +24,10 @@ import I18n from '@/context/i18n'
import { getLanguage } from '@/i18n/language'
import cn from '@/utils/classnames'
import type { ToolWithProvider } from '@/app/components/workflow/types'
import {
AuthCategory,
PluginAuthInAgent,
} from '@/app/components/plugins/plugin-auth'
type Props = {
showBackButton?: boolean
@@ -36,6 +39,8 @@ type Props = {
readonly?: boolean
onHide: () => void
onSave?: (value: Record<string, any>) => void
credentialId?: string
onAuthorizationItemClick?: (id: string) => void
}
const SettingBuiltInTool: FC<Props> = ({
@@ -48,6 +53,8 @@ const SettingBuiltInTool: FC<Props> = ({
readonly,
onHide,
onSave,
credentialId,
onAuthorizationItemClick,
}) => {
const { locale } = useContext(I18n)
const language = getLanguage(locale)
@@ -197,8 +204,20 @@ const SettingBuiltInTool: FC<Props> = ({
</div>
<div className='system-md-semibold mt-1 text-text-primary'>{currTool?.label[language]}</div>
{!!currTool?.description[language] && (
<Description className='mt-3' text={currTool.description[language]} descriptionLineRows={2}></Description>
<Description className='mb-2 mt-3 h-auto' text={currTool.description[language]} descriptionLineRows={2}></Description>
)}
{
collection.allow_delete && collection.type === CollectionType.builtIn && (
<PluginAuthInAgent
pluginPayload={{
provider: collection.name,
category: AuthCategory.tool,
}}
credentialId={credentialId}
onAuthorizationItemClick={onAuthorizationItemClick}
/>
)
}
</div>
{/* form */}
<div className='h-full'>

View File

@@ -117,7 +117,7 @@ const Question: FC<QuestionProps> = ({
</div>
<div
ref={contentRef}
className='w-full rounded-2xl bg-background-gradient-bg-fill-chat-bubble-bg-3 px-4 py-3 text-sm text-text-primary'
className='bg-background-gradient-bg-fill-chat-bubble-bg-3 w-full rounded-2xl px-4 py-3 text-sm text-text-primary'
style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}}
>
{

View File

@@ -0,0 +1,177 @@
import {
isValidElement,
memo,
useMemo,
} from 'react'
import type { AnyFieldApi } from '@tanstack/react-form'
import { useStore } from '@tanstack/react-form'
import cn from '@/utils/classnames'
import Input from '@/app/components/base/input'
import PureSelect from '@/app/components/base/select/pure'
import type { FormSchema } from '@/app/components/base/form/types'
import { FormTypeEnum } from '@/app/components/base/form/types'
import { useRenderI18nObject } from '@/hooks/use-i18n'
export type BaseFieldProps = {
fieldClassName?: string
labelClassName?: string
inputContainerClassName?: string
inputClassName?: string
formSchema: FormSchema
field: AnyFieldApi
disabled?: boolean
}
const BaseField = ({
fieldClassName,
labelClassName,
inputContainerClassName,
inputClassName,
formSchema,
field,
disabled,
}: BaseFieldProps) => {
const renderI18nObject = useRenderI18nObject()
const {
label,
required,
placeholder,
options,
labelClassName: formLabelClassName,
show_on = [],
} = formSchema
const memorizedLabel = useMemo(() => {
if (isValidElement(label))
return label
if (typeof label === 'string')
return label
if (typeof label === 'object' && label !== null)
return renderI18nObject(label as Record<string, string>)
}, [label, renderI18nObject])
const memorizedPlaceholder = useMemo(() => {
if (typeof placeholder === 'string')
return placeholder
if (typeof placeholder === 'object' && placeholder !== null)
return renderI18nObject(placeholder as Record<string, string>)
}, [placeholder, renderI18nObject])
const memorizedOptions = useMemo(() => {
return options?.map((option) => {
return {
label: typeof option.label === 'string' ? option.label : renderI18nObject(option.label),
value: option.value,
}
}) || []
}, [options, renderI18nObject])
const value = useStore(field.form.store, s => s.values[field.name])
const values = useStore(field.form.store, (s) => {
return show_on.reduce((acc, condition) => {
acc[condition.variable] = s.values[condition.variable]
return acc
}, {} as Record<string, any>)
})
const show = useMemo(() => {
return show_on.every((condition) => {
const conditionValue = values[condition.variable]
return conditionValue === condition.value
})
}, [values, show_on])
if (!show)
return null
return (
<div className={cn(fieldClassName)}>
<div className={cn(labelClassName, formLabelClassName)}>
{memorizedLabel}
{
required && !isValidElement(label) && (
<span className='ml-1 text-text-destructive-secondary'>*</span>
)
}
</div>
<div className={cn(inputContainerClassName)}>
{
formSchema.type === FormTypeEnum.textInput && (
<Input
id={field.name}
name={field.name}
className={cn(inputClassName)}
value={value || ''}
onChange={e => field.handleChange(e.target.value)}
onBlur={field.handleBlur}
disabled={disabled}
placeholder={memorizedPlaceholder}
/>
)
}
{
formSchema.type === FormTypeEnum.secretInput && (
<Input
id={field.name}
name={field.name}
type='password'
className={cn(inputClassName)}
value={value || ''}
onChange={e => field.handleChange(e.target.value)}
onBlur={field.handleBlur}
disabled={disabled}
placeholder={memorizedPlaceholder}
/>
)
}
{
formSchema.type === FormTypeEnum.textNumber && (
<Input
id={field.name}
name={field.name}
type='number'
className={cn(inputClassName)}
value={value || ''}
onChange={e => field.handleChange(e.target.value)}
onBlur={field.handleBlur}
disabled={disabled}
placeholder={memorizedPlaceholder}
/>
)
}
{
formSchema.type === FormTypeEnum.select && (
<PureSelect
value={value}
onChange={v => field.handleChange(v)}
disabled={disabled}
placeholder={memorizedPlaceholder}
options={memorizedOptions}
triggerPopupSameWidth
/>
)
}
{
formSchema.type === FormTypeEnum.radio && (
<div className='flex items-center space-x-2'>
{
memorizedOptions.map(option => (
<div
key={option.value}
className={cn(
'system-sm-regular hover:bg-components-option-card-option-hover-bg hover:border-components-option-card-option-hover-border flex h-8 grow cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg p-2 text-text-secondary',
value === option.value && 'border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary shadow-xs',
)}
onClick={() => field.handleChange(option.value)}
>
{option.label}
</div>
))
}
</div>
)
}
</div>
</div>
)
}
export default memo(BaseField)

View File

@@ -0,0 +1,115 @@
import {
memo,
useCallback,
useImperativeHandle,
} from 'react'
import type {
AnyFieldApi,
AnyFormApi,
} from '@tanstack/react-form'
import { useForm } from '@tanstack/react-form'
import type {
FormRef,
FormSchema,
} from '@/app/components/base/form/types'
import {
BaseField,
} from '.'
import type {
BaseFieldProps,
} from '.'
import cn from '@/utils/classnames'
import {
useGetFormValues,
useGetValidators,
} from '@/app/components/base/form/hooks'
export type BaseFormProps = {
formSchemas?: FormSchema[]
defaultValues?: Record<string, any>
formClassName?: string
ref?: FormRef
disabled?: boolean
formFromProps?: AnyFormApi
} & Pick<BaseFieldProps, 'fieldClassName' | 'labelClassName' | 'inputContainerClassName' | 'inputClassName'>
const BaseForm = ({
formSchemas = [],
defaultValues,
formClassName,
fieldClassName,
labelClassName,
inputContainerClassName,
inputClassName,
ref,
disabled,
formFromProps,
}: BaseFormProps) => {
const formFromHook = useForm({
defaultValues,
})
const form: any = formFromProps || formFromHook
const { getFormValues } = useGetFormValues(form, formSchemas)
const { getValidators } = useGetValidators()
useImperativeHandle(ref, () => {
return {
getForm() {
return form
},
getFormValues: (option) => {
return getFormValues(option)
},
}
}, [form, getFormValues])
const renderField = useCallback((field: AnyFieldApi) => {
const formSchema = formSchemas?.find(schema => schema.name === field.name)
if (formSchema) {
return (
<BaseField
field={field}
formSchema={formSchema}
fieldClassName={fieldClassName}
labelClassName={labelClassName}
inputContainerClassName={inputContainerClassName}
inputClassName={inputClassName}
disabled={disabled}
/>
)
}
return null
}, [formSchemas, fieldClassName, labelClassName, inputContainerClassName, inputClassName, disabled])
const renderFieldWrapper = useCallback((formSchema: FormSchema) => {
const validators = getValidators(formSchema)
const {
name,
} = formSchema
return (
<form.Field
key={name}
name={name}
validators={validators}
>
{renderField}
</form.Field>
)
}, [renderField, form, getValidators])
if (!formSchemas?.length)
return null
return (
<form
className={cn(formClassName)}
>
{formSchemas.map(renderFieldWrapper)}
</form>
)
}
export default memo(BaseForm)

View File

@@ -0,0 +1,2 @@
export { default as BaseForm, type BaseFormProps } from './base-form'
export { default as BaseField, type BaseFieldProps } from './base-field'

View File

@@ -0,0 +1,23 @@
import { memo } from 'react'
import { BaseForm } from '../../components/base'
import type { BaseFormProps } from '../../components/base'
const AuthForm = ({
formSchemas = [],
defaultValues,
ref,
formFromProps,
}: BaseFormProps) => {
return (
<BaseForm
ref={ref}
formSchemas={formSchemas}
defaultValues={defaultValues}
formClassName='space-y-4'
labelClassName='h-6 flex items-center mb-1 system-sm-medium text-text-secondary'
formFromProps={formFromProps}
/>
)
}
export default memo(AuthForm)

View File

@@ -0,0 +1,3 @@
export * from './use-check-validated'
export * from './use-get-form-values'
export * from './use-get-validators'

View File

@@ -0,0 +1,48 @@
import { useCallback } from 'react'
import type { AnyFormApi } from '@tanstack/react-form'
import { useToastContext } from '@/app/components/base/toast'
import type { FormSchema } from '@/app/components/base/form/types'
export const useCheckValidated = (form: AnyFormApi, FormSchemas: FormSchema[]) => {
const { notify } = useToastContext()
const checkValidated = useCallback(() => {
const allError = form?.getAllErrors()
const values = form.state.values
if (allError) {
const fields = allError.fields
const errorArray = Object.keys(fields).reduce((acc: string[], key: string) => {
const currentSchema = FormSchemas.find(schema => schema.name === key)
const { show_on = [] } = currentSchema || {}
const showOnValues = show_on.reduce((acc, condition) => {
acc[condition.variable] = values[condition.variable]
return acc
}, {} as Record<string, any>)
const show = show_on?.every((condition) => {
const conditionValue = showOnValues[condition.variable]
return conditionValue === condition.value
})
const errors: any[] = show ? fields[key].errors : []
return [...acc, ...errors]
}, [] as string[])
if (errorArray.length) {
notify({
type: 'error',
message: errorArray[0],
})
return false
}
return true
}
return true
}, [form, notify, FormSchemas])
return {
checkValidated,
}
}

View File

@@ -0,0 +1,44 @@
import { useCallback } from 'react'
import type { AnyFormApi } from '@tanstack/react-form'
import { useCheckValidated } from './use-check-validated'
import type {
FormSchema,
GetValuesOptions,
} from '../types'
import { getTransformedValuesWhenSecretInputPristine } from '../utils'
export const useGetFormValues = (form: AnyFormApi, formSchemas: FormSchema[]) => {
const { checkValidated } = useCheckValidated(form, formSchemas)
const getFormValues = useCallback((
{
needCheckValidatedValues,
needTransformWhenSecretFieldIsPristine,
}: GetValuesOptions,
) => {
const values = form?.store.state.values || {}
if (!needCheckValidatedValues) {
return {
values,
isCheckValidated: false,
}
}
if (checkValidated()) {
return {
values: needTransformWhenSecretFieldIsPristine ? getTransformedValuesWhenSecretInputPristine(formSchemas, form) : values,
isCheckValidated: true,
}
}
else {
return {
values: {},
isCheckValidated: false,
}
}
}, [form, checkValidated, formSchemas])
return {
getFormValues,
}
}

View File

@@ -0,0 +1,36 @@
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import type { FormSchema } from '../types'
export const useGetValidators = () => {
const { t } = useTranslation()
const getValidators = useCallback((formSchema: FormSchema) => {
const {
name,
validators,
required,
} = formSchema
let mergedValidators = validators
if (required && !validators) {
mergedValidators = {
onMount: ({ value }: any) => {
if (!value)
return t('common.errorMsg.fieldRequired', { field: name })
},
onChange: ({ value }: any) => {
if (!value)
return t('common.errorMsg.fieldRequired', { field: name })
},
onBlur: ({ value }: any) => {
if (!value)
return t('common.errorMsg.fieldRequired', { field: name })
},
}
}
return mergedValidators
}, [t])
return {
getValidators,
}
}

View File

@@ -0,0 +1,76 @@
import type {
ForwardedRef,
ReactNode,
} from 'react'
import type {
AnyFormApi,
FieldValidators,
} from '@tanstack/react-form'
export type TypeWithI18N<T = string> = {
en_US: T
zh_Hans: T
[key: string]: T
}
export type FormShowOnObject = {
variable: string
value: string
}
export enum FormTypeEnum {
textInput = 'text-input',
textNumber = 'number-input',
secretInput = 'secret-input',
select = 'select',
radio = 'radio',
boolean = 'boolean',
files = 'files',
file = 'file',
modelSelector = 'model-selector',
toolSelector = 'tool-selector',
multiToolSelector = 'array[tools]',
appSelector = 'app-selector',
dynamicSelect = 'dynamic-select',
}
export type FormOption = {
label: TypeWithI18N | string
value: string
show_on?: FormShowOnObject[]
icon?: string
}
export type AnyValidators = FieldValidators<any, any, any, any, any, any, any, any, any, any>
export type FormSchema = {
type: FormTypeEnum
name: string
label: string | ReactNode | TypeWithI18N
required: boolean
default?: any
tooltip?: string | TypeWithI18N
show_on?: FormShowOnObject[]
url?: string
scope?: string
help?: string | TypeWithI18N
placeholder?: string | TypeWithI18N
options?: FormOption[]
labelClassName?: string
validators?: AnyValidators
}
export type FormValues = Record<string, any>
export type GetValuesOptions = {
needTransformWhenSecretFieldIsPristine?: boolean
needCheckValidatedValues?: boolean
}
export type FormRefObject = {
getForm: () => AnyFormApi
getFormValues: (obj: GetValuesOptions) => {
values: Record<string, any>
isCheckValidated: boolean
}
}
export type FormRef = ForwardedRef<FormRefObject>

View File

@@ -0,0 +1 @@
export * from './secret-input'

View File

@@ -0,0 +1,29 @@
import type { AnyFormApi } from '@tanstack/react-form'
import type { FormSchema } from '@/app/components/base/form/types'
import { FormTypeEnum } from '@/app/components/base/form/types'
export const transformFormSchemasSecretInput = (isPristineSecretInputNames: string[], values: Record<string, any>) => {
const transformedValues: Record<string, any> = { ...values }
isPristineSecretInputNames.forEach((name) => {
if (transformedValues[name])
transformedValues[name] = '[__HIDDEN__]'
})
return transformedValues
}
export const getTransformedValuesWhenSecretInputPristine = (formSchemas: FormSchema[], form: AnyFormApi) => {
const values = form?.store.state.values || {}
const isPristineSecretInputNames: string[] = []
for (let i = 0; i < formSchemas.length; i++) {
const schema = formSchemas[i]
if (schema.type === FormTypeEnum.secretInput) {
const fieldMeta = form?.getFieldMeta(schema.name)
if (fieldMeta?.isPristine)
isPristineSecretInputNames.push(schema.name)
}
}
return transformFormSchemasSecretInput(isPristineSecretInputNames, values)
}

View File

@@ -0,0 +1,127 @@
import { memo } from 'react'
import { useTranslation } from 'react-i18next'
import { RiCloseLine } from '@remixicon/react'
import {
PortalToFollowElem,
PortalToFollowElemContent,
} from '@/app/components/base/portal-to-follow-elem'
import Button from '@/app/components/base/button'
import type { ButtonProps } from '@/app/components/base/button'
import cn from '@/utils/classnames'
type ModalProps = {
onClose?: () => void
size?: 'sm' | 'md'
title: string
subTitle?: string
children?: React.ReactNode
confirmButtonText?: string
onConfirm?: () => void
cancelButtonText?: string
onCancel?: () => void
showExtraButton?: boolean
extraButtonText?: string
extraButtonVariant?: ButtonProps['variant']
onExtraButtonClick?: () => void
footerSlot?: React.ReactNode
bottomSlot?: React.ReactNode
disabled?: boolean
}
const Modal = ({
onClose,
size = 'sm',
title,
subTitle,
children,
confirmButtonText,
onConfirm,
cancelButtonText,
onCancel,
showExtraButton,
extraButtonVariant = 'warning',
extraButtonText,
onExtraButtonClick,
footerSlot,
bottomSlot,
disabled,
}: ModalProps) => {
const { t } = useTranslation()
return (
<PortalToFollowElem open>
<PortalToFollowElemContent
className='z-[9998] flex h-full w-full items-center justify-center bg-background-overlay'
onClick={onClose}
>
<div
className={cn(
'max-h-[80%] w-[480px] overflow-y-auto rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-xs',
size === 'sm' && 'w-[480px',
size === 'md' && 'w-[640px]',
)}
onClick={e => e.stopPropagation()}
>
<div className='title-2xl-semi-bold relative p-6 pb-3 pr-14 text-text-primary'>
{title}
{
subTitle && (
<div className='system-xs-regular mt-1 text-text-tertiary'>
{subTitle}
</div>
)
}
<div
className='absolute right-5 top-5 flex h-8 w-8 cursor-pointer items-center justify-center rounded-lg'
onClick={onClose}
>
<RiCloseLine className='h-5 w-5 text-text-tertiary' />
</div>
</div>
{
children && (
<div className='px-6 py-3'>{children}</div>
)
}
<div className='flex justify-between p-6 pt-5'>
<div>
{footerSlot}
</div>
<div className='flex items-center'>
{
showExtraButton && (
<>
<Button
variant={extraButtonVariant}
onClick={onExtraButtonClick}
disabled={disabled}
>
{extraButtonText || t('common.operation.remove')}
</Button>
<div className='mx-3 h-4 w-[1px] bg-divider-regular'></div>
</>
)
}
<Button
onClick={onCancel}
disabled={disabled}
>
{cancelButtonText || t('common.operation.cancel')}
</Button>
<Button
className='ml-2'
variant='primary'
onClick={onConfirm}
disabled={disabled}
>
{confirmButtonText || t('common.operation.save')}
</Button>
</div>
</div>
{bottomSlot}
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
)
}
export default memo(Modal)

View File

@@ -39,6 +39,9 @@ type PureSelectProps = {
itemClassName?: string
title?: string
},
placeholder?: string
disabled?: boolean
triggerPopupSameWidth?: boolean
}
const PureSelect = ({
options,
@@ -47,6 +50,9 @@ const PureSelect = ({
containerProps,
triggerProps,
popupProps,
placeholder,
disabled,
triggerPopupSameWidth,
}: PureSelectProps) => {
const { t } = useTranslation()
const {
@@ -74,7 +80,7 @@ const PureSelect = ({
}, [onOpenChange])
const selectedOption = options.find(option => option.value === value)
const triggerText = selectedOption?.label || t('common.placeholder.select')
const triggerText = selectedOption?.label || placeholder || t('common.placeholder.select')
return (
<PortalToFollowElem
@@ -82,6 +88,7 @@ const PureSelect = ({
offset={offset || 4}
open={mergedOpen}
onOpenChange={handleOpenChange}
triggerPopupSameWidth={triggerPopupSameWidth}
>
<PortalToFollowElemTrigger
onClick={() => handleOpenChange(!mergedOpen)}
@@ -135,6 +142,7 @@ const PureSelect = ({
)}
title={option.label}
onClick={() => {
if (disabled) return
onChange?.(option.value)
handleOpenChange(false)
}}

View File

@@ -0,0 +1,50 @@
import {
memo,
useState,
} from 'react'
import Button from '@/app/components/base/button'
import type { ButtonProps } from '@/app/components/base/button'
import ApiKeyModal from './api-key-modal'
import type { PluginPayload } from '../types'
export type AddApiKeyButtonProps = {
pluginPayload: PluginPayload
buttonVariant?: ButtonProps['variant']
buttonText?: string
disabled?: boolean
onUpdate?: () => void
}
const AddApiKeyButton = ({
pluginPayload,
buttonVariant = 'secondary-accent',
buttonText = 'use api key',
disabled,
onUpdate,
}: AddApiKeyButtonProps) => {
const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false)
return (
<>
<Button
className='w-full'
variant={buttonVariant}
onClick={() => setIsApiKeyModalOpen(true)}
disabled={disabled}
>
{buttonText}
</Button>
{
isApiKeyModalOpen && (
<ApiKeyModal
pluginPayload={pluginPayload}
onClose={() => setIsApiKeyModalOpen(false)}
onUpdate={onUpdate}
/>
)
}
</>
)
}
export default memo(AddApiKeyButton)

View File

@@ -0,0 +1,259 @@
import {
memo,
useCallback,
useMemo,
useState,
} from 'react'
import { useTranslation } from 'react-i18next'
import {
RiClipboardLine,
RiEqualizer2Line,
RiInformation2Fill,
} from '@remixicon/react'
import Button from '@/app/components/base/button'
import type { ButtonProps } from '@/app/components/base/button'
import OAuthClientSettings from './oauth-client-settings'
import cn from '@/utils/classnames'
import type { PluginPayload } from '../types'
import { openOAuthPopup } from '@/hooks/use-oauth'
import Badge from '@/app/components/base/badge'
import {
useGetPluginOAuthClientSchemaHook,
useGetPluginOAuthUrlHook,
} from '../hooks/use-credential'
import type { FormSchema } from '@/app/components/base/form/types'
import { FormTypeEnum } from '@/app/components/base/form/types'
import ActionButton from '@/app/components/base/action-button'
import { useRenderI18nObject } from '@/hooks/use-i18n'
export type AddOAuthButtonProps = {
pluginPayload: PluginPayload
buttonVariant?: ButtonProps['variant']
buttonText?: string
className?: string
buttonLeftClassName?: string
buttonRightClassName?: string
dividerClassName?: string
disabled?: boolean
onUpdate?: () => void
}
const AddOAuthButton = ({
pluginPayload,
buttonVariant = 'primary',
buttonText = 'use oauth',
className,
buttonLeftClassName,
buttonRightClassName,
dividerClassName,
disabled,
onUpdate,
}: AddOAuthButtonProps) => {
const { t } = useTranslation()
const renderI18nObject = useRenderI18nObject()
const [isOAuthSettingsOpen, setIsOAuthSettingsOpen] = useState(false)
const { mutateAsync: getPluginOAuthUrl } = useGetPluginOAuthUrlHook(pluginPayload)
const { data, isLoading } = useGetPluginOAuthClientSchemaHook(pluginPayload)
const {
schema = [],
is_oauth_custom_client_enabled,
is_system_oauth_params_exists,
client_params,
redirect_uri,
} = data || {}
const isConfigured = is_system_oauth_params_exists || is_oauth_custom_client_enabled
const handleOAuth = useCallback(async () => {
const { authorization_url } = await getPluginOAuthUrl()
if (authorization_url) {
openOAuthPopup(
authorization_url,
() => onUpdate?.(),
)
}
}, [getPluginOAuthUrl, onUpdate])
const renderCustomLabel = useCallback((item: FormSchema) => {
return (
<div className='w-full'>
<div className='mb-4 flex rounded-xl bg-background-section-burn p-4'>
<div className='mr-3 flex h-9 w-9 shrink-0 items-center justify-center rounded-lg border-[0.5px] border-components-card-border bg-components-card-bg shadow-lg'>
<RiInformation2Fill className='h-5 w-5 text-text-accent' />
</div>
<div className='w-0 grow'>
<div className='system-sm-regular mb-1.5'>
{t('plugin.auth.clientInfo')}
</div>
{
redirect_uri && (
<div className='system-sm-medium flex w-full py-0.5'>
<div className='w-0 grow break-words'>{redirect_uri}</div>
<ActionButton
className='shrink-0'
onClick={() => {
navigator.clipboard.writeText(redirect_uri || '')
}}
>
<RiClipboardLine className='h-4 w-4' />
</ActionButton>
</div>
)
}
</div>
</div>
<div className='system-sm-medium flex h-6 items-center text-text-secondary'>
{renderI18nObject(item.label as Record<string, string>)}
{
item.required && (
<span className='ml-1 text-text-destructive-secondary'>*</span>
)
}
</div>
</div>
)
}, [t, redirect_uri, renderI18nObject])
const memorizedSchemas = useMemo(() => {
const result: FormSchema[] = schema.map((item, index) => {
return {
...item,
label: index === 0 ? renderCustomLabel(item) : item.label,
labelClassName: index === 0 ? 'h-auto' : undefined,
}
})
if (is_system_oauth_params_exists) {
result.unshift({
name: '__oauth_client__',
label: t('plugin.auth.oauthClient'),
type: FormTypeEnum.radio,
options: [
{
label: t('plugin.auth.default'),
value: 'default',
},
{
label: t('plugin.auth.custom'),
value: 'custom',
},
],
required: false,
default: is_oauth_custom_client_enabled ? 'custom' : 'default',
} as FormSchema)
result.forEach((item, index) => {
if (index > 0) {
item.show_on = [
{
variable: '__oauth_client__',
value: 'custom',
},
]
if (client_params)
item.default = client_params[item.name] || item.default
}
})
}
return result
}, [schema, renderCustomLabel, t, is_system_oauth_params_exists, is_oauth_custom_client_enabled, client_params])
const __auth_client__ = useMemo(() => {
if (isConfigured) {
if (is_oauth_custom_client_enabled)
return 'custom'
return 'default'
}
else {
if (is_system_oauth_params_exists)
return 'default'
return 'custom'
}
}, [isConfigured, is_oauth_custom_client_enabled, is_system_oauth_params_exists])
return (
<>
{
isConfigured && (
<Button
variant={buttonVariant}
className={cn(
'w-full px-0 py-0 hover:bg-components-button-primary-bg',
className,
)}
disabled={disabled}
onClick={handleOAuth}
>
<div className={cn(
'flex h-full w-0 grow items-center justify-center rounded-l-lg pl-0.5 hover:bg-components-button-primary-bg-hover',
buttonLeftClassName,
)}>
<div
className='truncate'
title={buttonText}
>
{buttonText}
</div>
{
is_oauth_custom_client_enabled && (
<Badge
className={cn(
'ml-1 mr-0.5',
buttonVariant === 'primary' && 'border-text-primary-on-surface bg-components-badge-bg-dimm text-text-primary-on-surface',
)}
>
{t('plugin.auth.custom')}
</Badge>
)
}
</div>
<div className={cn(
'h-4 w-[1px] shrink-0 bg-text-primary-on-surface opacity-[0.15]',
dividerClassName,
)}></div>
<div
className={cn(
'flex h-full w-8 shrink-0 items-center justify-center rounded-r-lg hover:bg-components-button-primary-bg-hover',
buttonRightClassName,
)}
onClick={(e) => {
e.stopPropagation()
setIsOAuthSettingsOpen(true)
}}
>
<RiEqualizer2Line className='h-4 w-4' />
</div>
</Button>
)
}
{
!isConfigured && (
<Button
variant={buttonVariant}
onClick={() => setIsOAuthSettingsOpen(true)}
disabled={disabled}
className='w-full'
>
<RiEqualizer2Line className='mr-0.5 h-4 w-4' />
{t('plugin.auth.setupOAuth')}
</Button>
)
}
{
isOAuthSettingsOpen && (
<OAuthClientSettings
pluginPayload={pluginPayload}
onClose={() => setIsOAuthSettingsOpen(false)}
disabled={disabled || isLoading}
schemas={memorizedSchemas}
onAuth={handleOAuth}
editValues={{
...client_params,
__oauth_client__: __auth_client__,
}}
hasOriginalClientParams={Object.keys(client_params || {}).length > 0}
onUpdate={onUpdate}
/>
)
}
</>
)
}
export default memo(AddOAuthButton)

View File

@@ -0,0 +1,181 @@
import {
memo,
useCallback,
useMemo,
useRef,
useState,
} from 'react'
import { useTranslation } from 'react-i18next'
import { RiExternalLinkLine } from '@remixicon/react'
import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security'
import Modal from '@/app/components/base/modal/modal'
import { CredentialTypeEnum } from '../types'
import AuthForm from '@/app/components/base/form/form-scenarios/auth'
import type { FormRefObject } from '@/app/components/base/form/types'
import { FormTypeEnum } from '@/app/components/base/form/types'
import { useToastContext } from '@/app/components/base/toast'
import Loading from '@/app/components/base/loading'
import type { PluginPayload } from '../types'
import {
useAddPluginCredentialHook,
useGetPluginCredentialSchemaHook,
useUpdatePluginCredentialHook,
} from '../hooks/use-credential'
import { useRenderI18nObject } from '@/hooks/use-i18n'
export type ApiKeyModalProps = {
pluginPayload: PluginPayload
onClose?: () => void
editValues?: Record<string, any>
onRemove?: () => void
disabled?: boolean
onUpdate?: () => void
}
const ApiKeyModal = ({
pluginPayload,
onClose,
editValues,
onRemove,
disabled,
onUpdate,
}: ApiKeyModalProps) => {
const { t } = useTranslation()
const { notify } = useToastContext()
const [doingAction, setDoingAction] = useState(false)
const doingActionRef = useRef(doingAction)
const handleSetDoingAction = useCallback((value: boolean) => {
doingActionRef.current = value
setDoingAction(value)
}, [])
const { data = [], isLoading } = useGetPluginCredentialSchemaHook(pluginPayload, CredentialTypeEnum.API_KEY)
const formSchemas = useMemo(() => {
return [
{
type: FormTypeEnum.textInput,
name: '__name__',
label: t('plugin.auth.authorizationName'),
required: false,
},
...data,
]
}, [data, t])
const defaultValues = formSchemas.reduce((acc, schema) => {
if (schema.default)
acc[schema.name] = schema.default
return acc
}, {} as Record<string, any>)
const helpField = formSchemas.find(schema => schema.url && schema.help)
const renderI18nObject = useRenderI18nObject()
const { mutateAsync: addPluginCredential } = useAddPluginCredentialHook(pluginPayload)
const { mutateAsync: updatePluginCredential } = useUpdatePluginCredentialHook(pluginPayload)
const formRef = useRef<FormRefObject>(null)
const handleConfirm = useCallback(async () => {
if (doingActionRef.current)
return
const {
isCheckValidated,
values,
} = formRef.current?.getFormValues({
needCheckValidatedValues: true,
needTransformWhenSecretFieldIsPristine: true,
}) || { isCheckValidated: false, values: {} }
if (!isCheckValidated)
return
try {
const {
__name__,
__credential_id__,
...restValues
} = values
handleSetDoingAction(true)
if (editValues) {
await updatePluginCredential({
credentials: restValues,
credential_id: __credential_id__,
name: __name__ || '',
})
}
else {
await addPluginCredential({
credentials: restValues,
type: CredentialTypeEnum.API_KEY,
name: __name__ || '',
})
}
notify({
type: 'success',
message: t('common.api.actionSuccess'),
})
onClose?.()
onUpdate?.()
}
finally {
handleSetDoingAction(false)
}
}, [addPluginCredential, onClose, onUpdate, updatePluginCredential, notify, t, editValues, handleSetDoingAction])
return (
<Modal
size='md'
title={t('plugin.auth.useApiAuth')}
subTitle={t('plugin.auth.useApiAuthDesc')}
onClose={onClose}
onCancel={onClose}
footerSlot={
helpField && (
<a
className='system-xs-regular mr-2 flex items-center py-2 text-text-accent'
href={helpField?.url}
target='_blank'
>
<span className='break-all'>
{renderI18nObject(helpField?.help as any)}
</span>
<RiExternalLinkLine className='ml-1 h-3 w-3' />
</a>
)
}
bottomSlot={
<div className='flex items-center justify-center bg-background-section-burn py-3 text-xs text-text-tertiary'>
<Lock01 className='mr-1 h-3 w-3 text-text-tertiary' />
{t('common.modelProvider.encrypted.front')}
<a
className='mx-1 text-text-accent'
target='_blank' rel='noopener noreferrer'
href='https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html'
>
PKCS1_OAEP
</a>
{t('common.modelProvider.encrypted.back')}
</div>
}
onConfirm={handleConfirm}
showExtraButton={!!editValues}
onExtraButtonClick={onRemove}
disabled={disabled || isLoading || doingAction}
>
{
isLoading && (
<div className='flex h-40 items-center justify-center'>
<Loading />
</div>
)
}
{
!isLoading && !!data.length && (
<AuthForm
ref={formRef}
formSchemas={formSchemas}
defaultValues={editValues || defaultValues}
disabled={disabled}
/>
)
}
</Modal>
)
}
export default memo(ApiKeyModal)

View File

@@ -0,0 +1,104 @@
import {
memo,
useMemo,
} from 'react'
import { useTranslation } from 'react-i18next'
import AddOAuthButton from './add-oauth-button'
import type { AddOAuthButtonProps } from './add-oauth-button'
import AddApiKeyButton from './add-api-key-button'
import type { AddApiKeyButtonProps } from './add-api-key-button'
import type { PluginPayload } from '../types'
type AuthorizeProps = {
pluginPayload: PluginPayload
theme?: 'primary' | 'secondary'
showDivider?: boolean
canOAuth?: boolean
canApiKey?: boolean
disabled?: boolean
onUpdate?: () => void
}
const Authorize = ({
pluginPayload,
theme = 'primary',
showDivider = true,
canOAuth,
canApiKey,
disabled,
onUpdate,
}: AuthorizeProps) => {
const { t } = useTranslation()
const oAuthButtonProps: AddOAuthButtonProps = useMemo(() => {
if (theme === 'secondary') {
return {
buttonText: !canApiKey ? t('plugin.auth.useOAuthAuth') : t('plugin.auth.addOAuth'),
buttonVariant: 'secondary',
className: 'hover:bg-components-button-secondary-bg',
buttonLeftClassName: 'hover:bg-components-button-secondary-bg-hover',
buttonRightClassName: 'hover:bg-components-button-secondary-bg-hover',
dividerClassName: 'bg-divider-regular opacity-100',
pluginPayload,
}
}
return {
buttonText: !canApiKey ? t('plugin.auth.useOAuthAuth') : t('plugin.auth.addOAuth'),
pluginPayload,
}
}, [canApiKey, theme, pluginPayload, t])
const apiKeyButtonProps: AddApiKeyButtonProps = useMemo(() => {
if (theme === 'secondary') {
return {
pluginPayload,
buttonVariant: 'secondary',
buttonText: !canOAuth ? t('plugin.auth.useApiAuth') : t('plugin.auth.addApi'),
}
}
return {
pluginPayload,
buttonText: !canOAuth ? t('plugin.auth.useApiAuth') : t('plugin.auth.addApi'),
buttonVariant: !canOAuth ? 'primary' : 'secondary-accent',
}
}, [canOAuth, theme, pluginPayload, t])
return (
<>
<div className='flex items-center space-x-1.5'>
{
canOAuth && (
<div className='min-w-0 flex-[1]'>
<AddOAuthButton
{...oAuthButtonProps}
disabled={disabled}
onUpdate={onUpdate}
/>
</div>
)
}
{
showDivider && canOAuth && canApiKey && (
<div className='system-2xs-medium-uppercase flex shrink-0 flex-col items-center justify-between text-text-tertiary'>
<div className='h-2 w-[1px] bg-divider-subtle'></div>
or
<div className='h-2 w-[1px] bg-divider-subtle'></div>
</div>
)
}
{
canApiKey && (
<div className='min-w-0 flex-[1]'>
<AddApiKeyButton
{...apiKeyButtonProps}
disabled={disabled}
onUpdate={onUpdate}
/>
</div>
)
}
</div>
</>
)
}
export default memo(Authorize)

View File

@@ -0,0 +1,188 @@
import {
memo,
useCallback,
useRef,
useState,
} from 'react'
import { RiExternalLinkLine } from '@remixicon/react'
import {
useForm,
useStore,
} from '@tanstack/react-form'
import { useTranslation } from 'react-i18next'
import Modal from '@/app/components/base/modal/modal'
import {
useDeletePluginOAuthCustomClientHook,
useInvalidPluginOAuthClientSchemaHook,
useSetPluginOAuthCustomClientHook,
} from '../hooks/use-credential'
import type { PluginPayload } from '../types'
import AuthForm from '@/app/components/base/form/form-scenarios/auth'
import type {
FormRefObject,
FormSchema,
} from '@/app/components/base/form/types'
import { useToastContext } from '@/app/components/base/toast'
import Button from '@/app/components/base/button'
import { useRenderI18nObject } from '@/hooks/use-i18n'
type OAuthClientSettingsProps = {
pluginPayload: PluginPayload
onClose?: () => void
editValues?: Record<string, any>
disabled?: boolean
schemas: FormSchema[]
onAuth?: () => Promise<void>
hasOriginalClientParams?: boolean
onUpdate?: () => void
}
const OAuthClientSettings = ({
pluginPayload,
onClose,
editValues,
disabled,
schemas,
onAuth,
hasOriginalClientParams,
onUpdate,
}: OAuthClientSettingsProps) => {
const { t } = useTranslation()
const { notify } = useToastContext()
const [doingAction, setDoingAction] = useState(false)
const doingActionRef = useRef(doingAction)
const handleSetDoingAction = useCallback((value: boolean) => {
doingActionRef.current = value
setDoingAction(value)
}, [])
const defaultValues = schemas.reduce((acc, schema) => {
if (schema.default)
acc[schema.name] = schema.default
return acc
}, {} as Record<string, any>)
const { mutateAsync: setPluginOAuthCustomClient } = useSetPluginOAuthCustomClientHook(pluginPayload)
const invalidPluginOAuthClientSchema = useInvalidPluginOAuthClientSchemaHook(pluginPayload)
const formRef = useRef<FormRefObject>(null)
const handleConfirm = useCallback(async () => {
if (doingActionRef.current)
return
try {
const {
isCheckValidated,
values,
} = formRef.current?.getFormValues({
needCheckValidatedValues: true,
needTransformWhenSecretFieldIsPristine: true,
}) || { isCheckValidated: false, values: {} }
if (!isCheckValidated)
throw new Error('error')
const {
__oauth_client__,
...restValues
} = values
handleSetDoingAction(true)
await setPluginOAuthCustomClient({
client_params: restValues,
enable_oauth_custom_client: __oauth_client__ === 'custom',
})
notify({
type: 'success',
message: t('common.api.actionSuccess'),
})
onClose?.()
onUpdate?.()
invalidPluginOAuthClientSchema()
}
finally {
handleSetDoingAction(false)
}
}, [onClose, onUpdate, invalidPluginOAuthClientSchema, setPluginOAuthCustomClient, notify, t, handleSetDoingAction])
const handleConfirmAndAuthorize = useCallback(async () => {
await handleConfirm()
if (onAuth)
await onAuth()
}, [handleConfirm, onAuth])
const { mutateAsync: deletePluginOAuthCustomClient } = useDeletePluginOAuthCustomClientHook(pluginPayload)
const handleRemove = useCallback(async () => {
if (doingActionRef.current)
return
try {
handleSetDoingAction(true)
await deletePluginOAuthCustomClient()
notify({
type: 'success',
message: t('common.api.actionSuccess'),
})
onClose?.()
onUpdate?.()
invalidPluginOAuthClientSchema()
}
finally {
handleSetDoingAction(false)
}
}, [onUpdate, invalidPluginOAuthClientSchema, deletePluginOAuthCustomClient, notify, t, handleSetDoingAction, onClose])
const form = useForm({
defaultValues: editValues || defaultValues,
})
const __oauth_client__ = useStore(form.store, s => s.values.__oauth_client__)
const helpField = schemas.find(schema => schema.url && schema.help)
const renderI18nObject = useRenderI18nObject()
return (
<Modal
title={t('plugin.auth.oauthClientSettings')}
confirmButtonText={t('plugin.auth.saveAndAuth')}
cancelButtonText={t('plugin.auth.saveOnly')}
extraButtonText={t('common.operation.cancel')}
showExtraButton
extraButtonVariant='secondary'
onExtraButtonClick={onClose}
onClose={onClose}
onCancel={handleConfirm}
onConfirm={handleConfirmAndAuthorize}
disabled={disabled || doingAction}
footerSlot={
__oauth_client__ === 'custom' && hasOriginalClientParams && (
<div className='grow'>
<Button
variant='secondary'
className='text-components-button-destructive-secondary-text'
disabled={disabled || doingAction || !editValues}
onClick={handleRemove}
>
{t('common.operation.remove')}
</Button>
</div>
)
}
>
<>
<AuthForm
formFromProps={form}
ref={formRef}
formSchemas={schemas}
defaultValues={editValues || defaultValues}
disabled={disabled}
/>
{
helpField && __oauth_client__ === 'custom' && (
<a
className='system-xs-regular mt-4 flex items-center text-text-accent'
href={helpField?.url}
target='_blank'
>
<span className='break-all'>
{renderI18nObject(helpField?.help as any)}
</span>
<RiExternalLinkLine className='ml-1 h-3 w-3' />
</a>
)}
</>
</Modal>
)
}
export default memo(OAuthClientSettings)

View File

@@ -0,0 +1,113 @@
import {
memo,
useCallback,
useState,
} from 'react'
import { useTranslation } from 'react-i18next'
import { RiArrowDownSLine } from '@remixicon/react'
import Button from '@/app/components/base/button'
import Indicator from '@/app/components/header/indicator'
import cn from '@/utils/classnames'
import type {
Credential,
PluginPayload,
} from './types'
import {
Authorized,
usePluginAuth,
} from '.'
type AuthorizedInNodeProps = {
pluginPayload: PluginPayload
onAuthorizationItemClick: (id: string) => void
credentialId?: string
}
const AuthorizedInNode = ({
pluginPayload,
onAuthorizationItemClick,
credentialId,
}: AuthorizedInNodeProps) => {
const { t } = useTranslation()
const [isOpen, setIsOpen] = useState(false)
const {
canApiKey,
canOAuth,
credentials,
disabled,
invalidPluginCredentialInfo,
} = usePluginAuth(pluginPayload, isOpen || !!credentialId)
const renderTrigger = useCallback((open?: boolean) => {
let label = ''
let removed = false
if (!credentialId) {
label = t('plugin.auth.workspaceDefault')
}
else {
const credential = credentials.find(c => c.id === credentialId)
label = credential ? credential.name : t('plugin.auth.authRemoved')
removed = !credential
}
return (
<Button
size='small'
className={cn(
open && !removed && 'bg-components-button-ghost-bg-hover',
removed && 'bg-transparent text-text-destructive',
)}
>
<Indicator
className='mr-1.5'
color={removed ? 'red' : 'green'}
/>
{label}
<RiArrowDownSLine
className={cn(
'h-3.5 w-3.5 text-components-button-ghost-text',
removed && 'text-text-destructive',
)}
/>
</Button>
)
}, [credentialId, credentials, t])
const extraAuthorizationItems: Credential[] = [
{
id: '__workspace_default__',
name: t('plugin.auth.workspaceDefault'),
provider: '',
is_default: !credentialId,
isWorkspaceDefault: true,
},
]
const handleAuthorizationItemClick = useCallback((id: string) => {
onAuthorizationItemClick(id)
setIsOpen(false)
}, [
onAuthorizationItemClick,
setIsOpen,
])
return (
<Authorized
pluginPayload={pluginPayload}
credentials={credentials}
canOAuth={canOAuth}
canApiKey={canApiKey}
renderTrigger={renderTrigger}
isOpen={isOpen}
onOpenChange={setIsOpen}
offset={4}
placement='bottom-end'
triggerPopupSameWidth={false}
popupClassName='w-[360px]'
disabled={disabled}
disableSetDefault
onItemClick={handleAuthorizationItemClick}
extraAuthorizationItems={extraAuthorizationItems}
showItemSelectedIcon
selectedCredentialId={credentialId || '__workspace_default__'}
onUpdate={invalidPluginCredentialInfo}
/>
)
}
export default memo(AuthorizedInNode)

View File

@@ -0,0 +1,342 @@
import {
memo,
useCallback,
useRef,
useState,
} from 'react'
import {
RiArrowDownSLine,
} from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import type {
PortalToFollowElemOptions,
} from '@/app/components/base/portal-to-follow-elem'
import Button from '@/app/components/base/button'
import Indicator from '@/app/components/header/indicator'
import cn from '@/utils/classnames'
import Confirm from '@/app/components/base/confirm'
import Authorize from '../authorize'
import type { Credential } from '../types'
import { CredentialTypeEnum } from '../types'
import ApiKeyModal from '../authorize/api-key-modal'
import Item from './item'
import { useToastContext } from '@/app/components/base/toast'
import type { PluginPayload } from '../types'
import {
useDeletePluginCredentialHook,
useSetPluginDefaultCredentialHook,
useUpdatePluginCredentialHook,
} from '../hooks/use-credential'
type AuthorizedProps = {
pluginPayload: PluginPayload
credentials: Credential[]
canOAuth?: boolean
canApiKey?: boolean
disabled?: boolean
renderTrigger?: (open?: boolean) => React.ReactNode
isOpen?: boolean
onOpenChange?: (open: boolean) => void
offset?: PortalToFollowElemOptions['offset']
placement?: PortalToFollowElemOptions['placement']
triggerPopupSameWidth?: boolean
popupClassName?: string
disableSetDefault?: boolean
onItemClick?: (id: string) => void
extraAuthorizationItems?: Credential[]
showItemSelectedIcon?: boolean
selectedCredentialId?: string
onUpdate?: () => void
}
const Authorized = ({
pluginPayload,
credentials,
canOAuth,
canApiKey,
disabled,
renderTrigger,
isOpen,
onOpenChange,
offset = 8,
placement = 'bottom-start',
triggerPopupSameWidth = true,
popupClassName,
disableSetDefault,
onItemClick,
extraAuthorizationItems,
showItemSelectedIcon,
selectedCredentialId,
onUpdate,
}: AuthorizedProps) => {
const { t } = useTranslation()
const { notify } = useToastContext()
const [isLocalOpen, setIsLocalOpen] = useState(false)
const mergedIsOpen = isOpen ?? isLocalOpen
const setMergedIsOpen = useCallback((open: boolean) => {
if (onOpenChange)
onOpenChange(open)
setIsLocalOpen(open)
}, [onOpenChange])
const oAuthCredentials = credentials.filter(credential => credential.credential_type === CredentialTypeEnum.OAUTH2)
const apiKeyCredentials = credentials.filter(credential => credential.credential_type === CredentialTypeEnum.API_KEY)
const pendingOperationCredentialId = useRef<string | null>(null)
const [deleteCredentialId, setDeleteCredentialId] = useState<string | null>(null)
const { mutateAsync: deletePluginCredential } = useDeletePluginCredentialHook(pluginPayload)
const openConfirm = useCallback((credentialId?: string) => {
if (credentialId)
pendingOperationCredentialId.current = credentialId
setDeleteCredentialId(pendingOperationCredentialId.current)
}, [])
const closeConfirm = useCallback(() => {
setDeleteCredentialId(null)
pendingOperationCredentialId.current = null
}, [])
const [doingAction, setDoingAction] = useState(false)
const doingActionRef = useRef(doingAction)
const handleSetDoingAction = useCallback((doing: boolean) => {
doingActionRef.current = doing
setDoingAction(doing)
}, [])
const handleConfirm = useCallback(async () => {
if (doingActionRef.current)
return
if (!pendingOperationCredentialId.current) {
setDeleteCredentialId(null)
return
}
try {
handleSetDoingAction(true)
await deletePluginCredential({ credential_id: pendingOperationCredentialId.current })
notify({
type: 'success',
message: t('common.api.actionSuccess'),
})
onUpdate?.()
setDeleteCredentialId(null)
pendingOperationCredentialId.current = null
}
finally {
handleSetDoingAction(false)
}
}, [deletePluginCredential, onUpdate, notify, t, handleSetDoingAction])
const [editValues, setEditValues] = useState<Record<string, any> | null>(null)
const handleEdit = useCallback((id: string, values: Record<string, any>) => {
pendingOperationCredentialId.current = id
setEditValues(values)
}, [])
const handleRemove = useCallback(() => {
setDeleteCredentialId(pendingOperationCredentialId.current)
}, [])
const { mutateAsync: setPluginDefaultCredential } = useSetPluginDefaultCredentialHook(pluginPayload)
const handleSetDefault = useCallback(async (id: string) => {
if (doingActionRef.current)
return
try {
handleSetDoingAction(true)
await setPluginDefaultCredential(id)
notify({
type: 'success',
message: t('common.api.actionSuccess'),
})
onUpdate?.()
}
finally {
handleSetDoingAction(false)
}
}, [setPluginDefaultCredential, onUpdate, notify, t, handleSetDoingAction])
const { mutateAsync: updatePluginCredential } = useUpdatePluginCredentialHook(pluginPayload)
const handleRename = useCallback(async (payload: {
credential_id: string
name: string
}) => {
if (doingActionRef.current)
return
try {
handleSetDoingAction(true)
await updatePluginCredential(payload)
notify({
type: 'success',
message: t('common.api.actionSuccess'),
})
onUpdate?.()
}
finally {
handleSetDoingAction(false)
}
}, [updatePluginCredential, notify, t, handleSetDoingAction, onUpdate])
return (
<>
<PortalToFollowElem
open={mergedIsOpen}
onOpenChange={setMergedIsOpen}
placement={placement}
offset={offset}
triggerPopupSameWidth={triggerPopupSameWidth}
>
<PortalToFollowElemTrigger
onClick={() => setMergedIsOpen(!mergedIsOpen)}
asChild
>
{
renderTrigger
? renderTrigger(mergedIsOpen)
: (
<Button
className={cn(
'w-full',
isOpen && 'bg-components-button-secondary-bg-hover',
)}>
<Indicator className='mr-2' />
{credentials.length}&nbsp;
{
credentials.length > 1
? t('plugin.auth.authorizations')
: t('plugin.auth.authorization')
}
<RiArrowDownSLine className='ml-0.5 h-4 w-4' />
</Button>
)
}
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className='z-[100]'>
<div className={cn(
'max-h-[360px] overflow-y-auto rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg',
popupClassName,
)}>
<div className='py-1'>
{
!!extraAuthorizationItems?.length && (
<div className='p-1'>
{
extraAuthorizationItems.map(credential => (
<Item
key={credential.id}
credential={credential}
disabled={disabled}
onItemClick={onItemClick}
disableRename
disableEdit
disableDelete
disableSetDefault
showSelectedIcon={showItemSelectedIcon}
selectedCredentialId={selectedCredentialId}
/>
))
}
</div>
)
}
{
!!oAuthCredentials.length && (
<div className='p-1'>
<div className={cn(
'system-xs-medium px-3 pb-0.5 pt-1 text-text-tertiary',
showItemSelectedIcon && 'pl-7',
)}>
OAuth
</div>
{
oAuthCredentials.map(credential => (
<Item
key={credential.id}
credential={credential}
disabled={disabled}
disableEdit
onDelete={openConfirm}
onSetDefault={handleSetDefault}
onRename={handleRename}
disableSetDefault={disableSetDefault}
onItemClick={onItemClick}
showSelectedIcon={showItemSelectedIcon}
selectedCredentialId={selectedCredentialId}
/>
))
}
</div>
)
}
{
!!apiKeyCredentials.length && (
<div className='p-1'>
<div className={cn(
'system-xs-medium px-3 pb-0.5 pt-1 text-text-tertiary',
showItemSelectedIcon && 'pl-7',
)}>
API Keys
</div>
{
apiKeyCredentials.map(credential => (
<Item
key={credential.id}
credential={credential}
disabled={disabled}
onDelete={openConfirm}
onEdit={handleEdit}
onSetDefault={handleSetDefault}
disableSetDefault={disableSetDefault}
disableRename
onItemClick={onItemClick}
onRename={handleRename}
showSelectedIcon={showItemSelectedIcon}
selectedCredentialId={selectedCredentialId}
/>
))
}
</div>
)
}
</div>
<div className='h-[1px] bg-divider-subtle'></div>
<div className='p-2'>
<Authorize
pluginPayload={pluginPayload}
theme='secondary'
showDivider={false}
canOAuth={canOAuth}
canApiKey={canApiKey}
disabled={disabled}
onUpdate={onUpdate}
/>
</div>
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
{
deleteCredentialId && (
<Confirm
isShow
title={t('datasetDocuments.list.delete.title')}
isDisabled={doingAction}
onCancel={closeConfirm}
onConfirm={handleConfirm}
/>
)
}
{
!!editValues && (
<ApiKeyModal
pluginPayload={pluginPayload}
editValues={editValues}
onClose={() => {
setEditValues(null)
pendingOperationCredentialId.current = null
}}
onRemove={handleRemove}
disabled={disabled || doingAction}
onUpdate={onUpdate}
/>
)
}
</>
)
}
export default memo(Authorized)

View File

@@ -0,0 +1,219 @@
import {
memo,
useMemo,
useState,
} from 'react'
import { useTranslation } from 'react-i18next'
import {
RiCheckLine,
RiDeleteBinLine,
RiEditLine,
RiEqualizer2Line,
} from '@remixicon/react'
import Indicator from '@/app/components/header/indicator'
import Badge from '@/app/components/base/badge'
import ActionButton from '@/app/components/base/action-button'
import Tooltip from '@/app/components/base/tooltip'
import Button from '@/app/components/base/button'
import Input from '@/app/components/base/input'
import cn from '@/utils/classnames'
import type { Credential } from '../types'
import { CredentialTypeEnum } from '../types'
type ItemProps = {
credential: Credential
disabled?: boolean
onDelete?: (id: string) => void
onEdit?: (id: string, values: Record<string, any>) => void
onSetDefault?: (id: string) => void
onRename?: (payload: {
credential_id: string
name: string
}) => void
disableRename?: boolean
disableEdit?: boolean
disableDelete?: boolean
disableSetDefault?: boolean
onItemClick?: (id: string) => void
showSelectedIcon?: boolean
selectedCredentialId?: string
}
const Item = ({
credential,
disabled,
onDelete,
onEdit,
onSetDefault,
onRename,
disableRename,
disableEdit,
disableDelete,
disableSetDefault,
onItemClick,
showSelectedIcon,
selectedCredentialId,
}: ItemProps) => {
const { t } = useTranslation()
const [renaming, setRenaming] = useState(false)
const [renameValue, setRenameValue] = useState(credential.name)
const isOAuth = credential.credential_type === CredentialTypeEnum.OAUTH2
const showAction = useMemo(() => {
return !(disableRename && disableEdit && disableDelete && disableSetDefault)
}, [disableRename, disableEdit, disableDelete, disableSetDefault])
return (
<div
key={credential.id}
className={cn(
'group flex h-8 items-center rounded-lg p-1 hover:bg-state-base-hover',
renaming && 'bg-state-base-hover',
)}
onClick={() => onItemClick?.(credential.id === '__workspace_default__' ? '' : credential.id)}
>
{
renaming && (
<div className='flex w-full items-center space-x-1'>
<Input
wrapperClassName='grow rounded-[6px]'
className='h-6'
value={renameValue}
onChange={e => setRenameValue(e.target.value)}
placeholder={t('common.placeholder.input')}
onClick={e => e.stopPropagation()}
/>
<Button
size='small'
variant='primary'
onClick={(e) => {
e.stopPropagation()
onRename?.({
credential_id: credential.id,
name: renameValue,
})
setRenaming(false)
}}
>
{t('common.operation.save')}
</Button>
<Button
size='small'
onClick={(e) => {
e.stopPropagation()
setRenaming(false)
}}
>
{t('common.operation.cancel')}
</Button>
</div>
)
}
{
!renaming && (
<div className='flex w-0 grow items-center space-x-1.5'>
{
showSelectedIcon && (
<div className='h-4 w-4'>
{
selectedCredentialId === credential.id && (
<RiCheckLine className='h-4 w-4 text-text-accent' />
)
}
</div>
)
}
<Indicator className='ml-2 mr-1.5 shrink-0' />
<div
className='system-md-regular truncate text-text-secondary'
title={credential.name}
>
{credential.name}
</div>
{
credential.is_default && (
<Badge className='shrink-0'>
{t('plugin.auth.default')}
</Badge>
)
}
</div>
)
}
{
showAction && !renaming && (
<div className='ml-2 hidden shrink-0 items-center group-hover:flex'>
{
!credential.is_default && !disableSetDefault && (
<Button
size='small'
disabled={disabled}
onClick={(e) => {
e.stopPropagation()
onSetDefault?.(credential.id)
}}
>
{t('plugin.auth.setDefault')}
</Button>
)
}
{
!disableRename && (
<Tooltip popupContent={t('common.operation.rename')}>
<ActionButton
disabled={disabled}
onClick={(e) => {
e.stopPropagation()
setRenaming(true)
setRenameValue(credential.name)
}}
>
<RiEditLine className='h-4 w-4 text-text-tertiary' />
</ActionButton>
</Tooltip>
)
}
{
!isOAuth && !disableEdit && (
<Tooltip popupContent={t('common.operation.edit')}>
<ActionButton
disabled={disabled}
onClick={(e) => {
e.stopPropagation()
onEdit?.(
credential.id,
{
...credential.credentials,
__name__: credential.name,
__credential_id__: credential.id,
},
)
}}
>
<RiEqualizer2Line className='h-4 w-4 text-text-tertiary' />
</ActionButton>
</Tooltip>
)
}
{
!disableDelete && (
<Tooltip popupContent={t('common.operation.delete')}>
<ActionButton
className='hover:bg-transparent'
disabled={disabled}
onClick={(e) => {
e.stopPropagation()
onDelete?.(credential.id)
}}
>
<RiDeleteBinLine className='h-4 w-4 text-text-tertiary hover:text-text-destructive' />
</ActionButton>
</Tooltip>
)
}
</div>
)
}
</div>
)
}
export default memo(Item)

View File

@@ -0,0 +1,88 @@
import {
useAddPluginCredential,
useDeletePluginCredential,
useDeletePluginOAuthCustomClient,
useGetPluginCredentialInfo,
useGetPluginCredentialSchema,
useGetPluginOAuthClientSchema,
useGetPluginOAuthUrl,
useInvalidPluginCredentialInfo,
useInvalidPluginOAuthClientSchema,
useSetPluginDefaultCredential,
useSetPluginOAuthCustomClient,
useUpdatePluginCredential,
} from '@/service/use-plugins-auth'
import { useGetApi } from './use-get-api'
import type { PluginPayload } from '../types'
import type { CredentialTypeEnum } from '../types'
export const useGetPluginCredentialInfoHook = (pluginPayload: PluginPayload, enable?: boolean) => {
const apiMap = useGetApi(pluginPayload)
return useGetPluginCredentialInfo(enable ? apiMap.getCredentialInfo : '')
}
export const useDeletePluginCredentialHook = (pluginPayload: PluginPayload) => {
const apiMap = useGetApi(pluginPayload)
return useDeletePluginCredential(apiMap.deleteCredential)
}
export const useInvalidPluginCredentialInfoHook = (pluginPayload: PluginPayload) => {
const apiMap = useGetApi(pluginPayload)
return useInvalidPluginCredentialInfo(apiMap.getCredentialInfo)
}
export const useSetPluginDefaultCredentialHook = (pluginPayload: PluginPayload) => {
const apiMap = useGetApi(pluginPayload)
return useSetPluginDefaultCredential(apiMap.setDefaultCredential)
}
export const useGetPluginCredentialSchemaHook = (pluginPayload: PluginPayload, credentialType: CredentialTypeEnum) => {
const apiMap = useGetApi(pluginPayload)
return useGetPluginCredentialSchema(apiMap.getCredentialSchema(credentialType))
}
export const useAddPluginCredentialHook = (pluginPayload: PluginPayload) => {
const apiMap = useGetApi(pluginPayload)
return useAddPluginCredential(apiMap.addCredential)
}
export const useUpdatePluginCredentialHook = (pluginPayload: PluginPayload) => {
const apiMap = useGetApi(pluginPayload)
return useUpdatePluginCredential(apiMap.updateCredential)
}
export const useGetPluginOAuthUrlHook = (pluginPayload: PluginPayload) => {
const apiMap = useGetApi(pluginPayload)
return useGetPluginOAuthUrl(apiMap.getOauthUrl)
}
export const useGetPluginOAuthClientSchemaHook = (pluginPayload: PluginPayload) => {
const apiMap = useGetApi(pluginPayload)
return useGetPluginOAuthClientSchema(apiMap.getOauthClientSchema)
}
export const useInvalidPluginOAuthClientSchemaHook = (pluginPayload: PluginPayload) => {
const apiMap = useGetApi(pluginPayload)
return useInvalidPluginOAuthClientSchema(apiMap.getOauthClientSchema)
}
export const useSetPluginOAuthCustomClientHook = (pluginPayload: PluginPayload) => {
const apiMap = useGetApi(pluginPayload)
return useSetPluginOAuthCustomClient(apiMap.setCustomOauthClient)
}
export const useDeletePluginOAuthCustomClientHook = (pluginPayload: PluginPayload) => {
const apiMap = useGetApi(pluginPayload)
return useDeletePluginOAuthCustomClient(apiMap.deleteCustomOAuthClient)
}

View File

@@ -0,0 +1,41 @@
import {
AuthCategory,
} from '../types'
import type {
CredentialTypeEnum,
PluginPayload,
} from '../types'
export const useGetApi = ({ category = AuthCategory.tool, provider }: PluginPayload) => {
if (category === AuthCategory.tool) {
return {
getCredentialInfo: `/workspaces/current/tool-provider/builtin/${provider}/credential/info`,
setDefaultCredential: `/workspaces/current/tool-provider/builtin/${provider}/default-credential`,
getCredentials: `/workspaces/current/tool-provider/builtin/${provider}/credentials`,
addCredential: `/workspaces/current/tool-provider/builtin/${provider}/add`,
updateCredential: `/workspaces/current/tool-provider/builtin/${provider}/update`,
deleteCredential: `/workspaces/current/tool-provider/builtin/${provider}/delete`,
getCredentialSchema: (credential_type: CredentialTypeEnum) => `/workspaces/current/tool-provider/builtin/${provider}/credential/schema/${credential_type}`,
getOauthUrl: `/oauth/plugin/${provider}/tool/authorization-url`,
getOauthClientSchema: `/workspaces/current/tool-provider/builtin/${provider}/oauth/client-schema`,
setCustomOauthClient: `/workspaces/current/tool-provider/builtin/${provider}/oauth/custom-client`,
getCustomOAuthClientValues: `/workspaces/current/tool-provider/builtin/${provider}/oauth/custom-client`,
deleteCustomOAuthClient: `/workspaces/current/tool-provider/builtin/${provider}/oauth/custom-client`,
}
}
return {
getCredentialInfo: '',
setDefaultCredential: '',
getCredentials: '',
addCredential: '',
updateCredential: '',
deleteCredential: '',
getCredentialSchema: () => '',
getOauthUrl: '',
getOauthClientSchema: '',
setCustomOauthClient: '',
getCustomOAuthClientValues: '',
deleteCustomOAuthClient: '',
}
}

View File

@@ -0,0 +1,25 @@
import { useAppContext } from '@/context/app-context'
import {
useGetPluginCredentialInfoHook,
useInvalidPluginCredentialInfoHook,
} from './use-credential'
import { CredentialTypeEnum } from '../types'
import type { PluginPayload } from '../types'
export const usePluginAuth = (pluginPayload: PluginPayload, enable?: boolean) => {
const { data } = useGetPluginCredentialInfoHook(pluginPayload, enable)
const { isCurrentWorkspaceManager } = useAppContext()
const isAuthorized = !!data?.credentials.length
const canOAuth = data?.supported_credential_types.includes(CredentialTypeEnum.OAUTH2)
const canApiKey = data?.supported_credential_types.includes(CredentialTypeEnum.API_KEY)
const invalidPluginCredentialInfo = useInvalidPluginCredentialInfoHook(pluginPayload)
return {
isAuthorized,
canOAuth,
canApiKey,
credentials: data?.credentials || [],
disabled: !isCurrentWorkspaceManager,
invalidPluginCredentialInfo,
}
}

View File

@@ -0,0 +1,6 @@
export { default as PluginAuth } from './plugin-auth'
export { default as Authorized } from './authorized'
export { default as AuthorizedInNode } from './authorized-in-node'
export { default as PluginAuthInAgent } from './plugin-auth-in-agent'
export { usePluginAuth } from './hooks/use-plugin-auth'
export * from './types'

View File

@@ -0,0 +1,123 @@
import {
memo,
useCallback,
useState,
} from 'react'
import { RiArrowDownSLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import Authorize from './authorize'
import Authorized from './authorized'
import type {
Credential,
PluginPayload,
} from './types'
import { usePluginAuth } from './hooks/use-plugin-auth'
import Button from '@/app/components/base/button'
import Indicator from '@/app/components/header/indicator'
import cn from '@/utils/classnames'
type PluginAuthInAgentProps = {
pluginPayload: PluginPayload
credentialId?: string
onAuthorizationItemClick?: (id: string) => void
}
const PluginAuthInAgent = ({
pluginPayload,
credentialId,
onAuthorizationItemClick,
}: PluginAuthInAgentProps) => {
const { t } = useTranslation()
const [isOpen, setIsOpen] = useState(false)
const {
isAuthorized,
canOAuth,
canApiKey,
credentials,
disabled,
invalidPluginCredentialInfo,
} = usePluginAuth(pluginPayload, true)
const extraAuthorizationItems: Credential[] = [
{
id: '__workspace_default__',
name: t('plugin.auth.workspaceDefault'),
provider: '',
is_default: !credentialId,
isWorkspaceDefault: true,
},
]
const handleAuthorizationItemClick = useCallback((id: string) => {
onAuthorizationItemClick?.(id)
setIsOpen(false)
}, [
onAuthorizationItemClick,
setIsOpen,
])
const renderTrigger = useCallback((isOpen?: boolean) => {
let label = ''
let removed = false
if (!credentialId) {
label = t('plugin.auth.workspaceDefault')
}
else {
const credential = credentials.find(c => c.id === credentialId)
label = credential ? credential.name : t('plugin.auth.authRemoved')
removed = !credential
}
return (
<Button
className={cn(
'w-full',
isOpen && 'bg-components-button-secondary-bg-hover',
removed && 'text-text-destructive',
)}>
<Indicator
className='mr-2'
color={removed ? 'red' : 'green'}
/>
{label}
<RiArrowDownSLine className='ml-0.5 h-4 w-4' />
</Button>
)
}, [credentialId, credentials, t])
return (
<>
{
!isAuthorized && (
<Authorize
pluginPayload={pluginPayload}
canOAuth={canOAuth}
canApiKey={canApiKey}
disabled={disabled}
onUpdate={invalidPluginCredentialInfo}
/>
)
}
{
isAuthorized && (
<Authorized
pluginPayload={pluginPayload}
credentials={credentials}
canOAuth={canOAuth}
canApiKey={canApiKey}
disabled={disabled}
disableSetDefault
onItemClick={handleAuthorizationItemClick}
extraAuthorizationItems={extraAuthorizationItems}
showItemSelectedIcon
renderTrigger={renderTrigger}
isOpen={isOpen}
onOpenChange={setIsOpen}
selectedCredentialId={credentialId || '__workspace_default__'}
onUpdate={invalidPluginCredentialInfo}
/>
)
}
</>
)
}
export default memo(PluginAuthInAgent)

View File

@@ -0,0 +1,59 @@
import { memo } from 'react'
import Authorize from './authorize'
import Authorized from './authorized'
import type { PluginPayload } from './types'
import { usePluginAuth } from './hooks/use-plugin-auth'
import cn from '@/utils/classnames'
type PluginAuthProps = {
pluginPayload: PluginPayload
children?: React.ReactNode
className?: string
}
const PluginAuth = ({
pluginPayload,
children,
className,
}: PluginAuthProps) => {
const {
isAuthorized,
canOAuth,
canApiKey,
credentials,
disabled,
invalidPluginCredentialInfo,
} = usePluginAuth(pluginPayload, !!pluginPayload.provider)
return (
<div className={cn(!isAuthorized && className)}>
{
!isAuthorized && (
<Authorize
pluginPayload={pluginPayload}
canOAuth={canOAuth}
canApiKey={canApiKey}
disabled={disabled}
onUpdate={invalidPluginCredentialInfo}
/>
)
}
{
isAuthorized && !children && (
<Authorized
pluginPayload={pluginPayload}
credentials={credentials}
canOAuth={canOAuth}
canApiKey={canApiKey}
disabled={disabled}
onUpdate={invalidPluginCredentialInfo}
/>
)
}
{
isAuthorized && children
}
</div>
)
}
export default memo(PluginAuth)

View File

@@ -0,0 +1,25 @@
export enum AuthCategory {
tool = 'tool',
datasource = 'datasource',
model = 'model',
}
export type PluginPayload = {
category: AuthCategory
provider: string
}
export enum CredentialTypeEnum {
OAUTH2 = 'oauth2',
API_KEY = 'api-key',
}
export type Credential = {
id: string
name: string
provider: string
credential_type?: CredentialTypeEnum
is_default: boolean
credentials?: Record<string, any>
isWorkspaceDefault?: boolean
}

View File

@@ -0,0 +1,10 @@
export const transformFormSchemasSecretInput = (isPristineSecretInputNames: string[], values: Record<string, any>) => {
const transformedValues: Record<string, any> = { ...values }
isPristineSecretInputNames.forEach((name) => {
if (transformedValues[name])
transformedValues[name] = '[__HIDDEN__]'
})
return transformedValues
}

View File

@@ -1,17 +1,9 @@
import React, { useMemo, useState } from 'react'
import React, { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { useAppContext } from '@/context/app-context'
import Button from '@/app/components/base/button'
import Toast from '@/app/components/base/toast'
import Indicator from '@/app/components/header/indicator'
import ToolItem from '@/app/components/tools/provider/tool-item'
import ConfigCredential from '@/app/components/tools/setting/build-in/config-credentials'
import {
useAllToolProviders,
useBuiltinTools,
useInvalidateAllToolProviders,
useRemoveProviderCredentials,
useUpdateProviderCredentials,
} from '@/service/use-tools'
import type { PluginDetail } from '@/app/components/plugins/types'
@@ -23,35 +15,14 @@ const ActionList = ({
detail,
}: Props) => {
const { t } = useTranslation()
const { isCurrentWorkspaceManager } = useAppContext()
const providerBriefInfo = detail.declaration.tool.identity
const providerKey = `${detail.plugin_id}/${providerBriefInfo.name}`
const { data: collectionList = [] } = useAllToolProviders()
const invalidateAllToolProviders = useInvalidateAllToolProviders()
const provider = useMemo(() => {
return collectionList.find(collection => collection.name === providerKey)
}, [collectionList, providerKey])
const { data } = useBuiltinTools(providerKey)
const [showSettingAuth, setShowSettingAuth] = useState(false)
const handleCredentialSettingUpdate = () => {
invalidateAllToolProviders()
Toast.notify({
type: 'success',
message: t('common.api.actionSuccess'),
})
setShowSettingAuth(false)
}
const { mutate: updatePermission, isPending } = useUpdateProviderCredentials({
onSuccess: handleCredentialSettingUpdate,
})
const { mutate: removePermission } = useRemoveProviderCredentials({
onSuccess: handleCredentialSettingUpdate,
})
if (!data || !provider)
return null
@@ -60,26 +31,7 @@ const ActionList = ({
<div className='mb-1 py-1'>
<div className='system-sm-semibold-uppercase mb-1 flex h-6 items-center justify-between text-text-secondary'>
{t('plugin.detailPanel.actionNum', { num: data.length, action: data.length > 1 ? 'actions' : 'action' })}
{provider.is_team_authorization && provider.allow_delete && (
<Button
variant='secondary'
size='small'
onClick={() => setShowSettingAuth(true)}
disabled={!isCurrentWorkspaceManager}
>
<Indicator className='mr-2' color={'green'} />
{t('tools.auth.authorized')}
</Button>
)}
</div>
{!provider.is_team_authorization && provider.allow_delete && (
<Button
variant='primary'
className='w-full'
onClick={() => setShowSettingAuth(true)}
disabled={!isCurrentWorkspaceManager}
>{t('workflow.nodes.tool.authorize')}</Button>
)}
</div>
<div className='flex flex-col gap-2'>
{data.map(tool => (
@@ -93,18 +45,6 @@ const ActionList = ({
/>
))}
</div>
{showSettingAuth && (
<ConfigCredential
collection={provider}
onCancel={() => setShowSettingAuth(false)}
onSaved={async value => updatePermission({
providerName: provider.name,
credentials: value,
})}
onRemove={async () => removePermission(provider.name)}
isSaving={isPending}
/>
)}
</div>
)
}

View File

@@ -36,6 +36,9 @@ import { useInvalidateAllToolProviders } from '@/service/use-tools'
import { API_PREFIX } from '@/config'
import cn from '@/utils/classnames'
import { getMarketplaceUrl } from '@/utils/var'
import { PluginAuth } from '@/app/components/plugins/plugin-auth'
import { AuthCategory } from '@/app/components/plugins/plugin-auth'
import { useAllToolProviders } from '@/service/use-tools'
const i18nPrefix = 'plugin.action'
@@ -68,7 +71,14 @@ const DetailHeader = ({
meta,
plugin_id,
} = detail
const { author, category, name, label, description, icon, verified } = detail.declaration
const { author, category, name, label, description, icon, verified, tool } = detail.declaration
const isTool = category === PluginType.tool
const providerBriefInfo = tool?.identity
const providerKey = `${plugin_id}/${providerBriefInfo?.name}`
const { data: collectionList = [] } = useAllToolProviders(isTool)
const provider = useMemo(() => {
return collectionList.find(collection => collection.name === providerKey)
}, [collectionList, providerKey])
const isFromGitHub = source === PluginSource.github
const isFromMarketplace = source === PluginSource.marketplace
@@ -262,7 +272,17 @@ const DetailHeader = ({
</ActionButton>
</div>
</div>
<Description className='mt-3' text={description[locale]} descriptionLineRows={2}></Description>
<Description className='mb-2 mt-3 h-auto' text={description[locale]} descriptionLineRows={2}></Description>
{
category === PluginType.tool && (
<PluginAuth
pluginPayload={{
provider: provider?.name || '',
category: AuthCategory.tool,
}}
/>
)
}
{isShowPluginInfo && (
<PluginInfo
repository={isFromGitHub ? meta?.repo : ''}

View File

@@ -3,9 +3,6 @@ import type { FC } from 'react'
import React, { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Link from 'next/link'
import {
RiArrowLeftLine,
} from '@remixicon/react'
import {
PortalToFollowElem,
PortalToFollowElemContent,
@@ -15,24 +12,17 @@ import ToolTrigger from '@/app/components/plugins/plugin-detail-panel/tool-selec
import ToolItem from '@/app/components/plugins/plugin-detail-panel/tool-selector/tool-item'
import ToolPicker from '@/app/components/workflow/block-selector/tool-picker'
import ToolForm from '@/app/components/workflow/nodes/tool/components/tool-form'
import Button from '@/app/components/base/button'
import Indicator from '@/app/components/header/indicator'
import ToolCredentialForm from '@/app/components/plugins/plugin-detail-panel/tool-selector/tool-credentials-form'
import Toast from '@/app/components/base/toast'
import Textarea from '@/app/components/base/textarea'
import Divider from '@/app/components/base/divider'
import TabSlider from '@/app/components/base/tab-slider-plain'
import ReasoningConfigForm from '@/app/components/plugins/plugin-detail-panel/tool-selector/reasoning-config-form'
import { generateFormValue, getPlainValue, getStructureValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
import { useAppContext } from '@/context/app-context'
import {
useAllBuiltInTools,
useAllCustomTools,
useAllMCPTools,
useAllWorkflowTools,
useInvalidateAllBuiltInTools,
useUpdateProviderCredentials,
} from '@/service/use-tools'
import { useInvalidateInstalledPluginList } from '@/service/use-plugins'
import { usePluginInstalledCheck } from '@/app/components/plugins/plugin-detail-panel/tool-selector/hooks'
@@ -46,6 +36,10 @@ import { MARKETPLACE_API_PREFIX } from '@/config'
import type { Node } from 'reactflow'
import type { NodeOutPutVar } from '@/app/components/workflow/types'
import cn from '@/utils/classnames'
import {
AuthCategory,
PluginAuthInAgent,
} from '@/app/components/plugins/plugin-auth'
type Props = {
disabled?: boolean
@@ -196,23 +190,6 @@ const ToolSelector: FC<Props> = ({
} as any)
}
// authorization
const { isCurrentWorkspaceManager } = useAppContext()
const [isShowSettingAuth, setShowSettingAuth] = useState(false)
const handleCredentialSettingUpdate = () => {
invalidateAllBuiltinTools()
Toast.notify({
type: 'success',
message: t('common.api.actionSuccess'),
})
setShowSettingAuth(false)
onShowChange(false)
}
const { mutate: updatePermission } = useUpdateProviderCredentials({
onSuccess: handleCredentialSettingUpdate,
})
// install from marketplace
const currentTool = useMemo(() => {
return currentProvider?.tools.find(tool => tool.name === value?.tool_name)
@@ -226,6 +203,12 @@ const ToolSelector: FC<Props> = ({
invalidateAllBuiltinTools()
invalidateInstalledPluginList()
}
const handleAuthorizationItemClick = (id: string) => {
onSelect({
...value,
credential_id: id,
} as any)
}
return (
<>
@@ -264,7 +247,6 @@ const ToolSelector: FC<Props> = ({
onSwitchChange={handleEnabledChange}
onDelete={onDelete}
noAuth={currentProvider && currentTool && !currentProvider.is_team_authorization}
onAuth={() => setShowSettingAuth(true)}
uninstalled={!currentProvider && inMarketPlace}
versionMismatch={currentProvider && inMarketPlace && !currentTool}
installInfo={manifest?.latest_package_identifier}
@@ -284,171 +266,131 @@ const ToolSelector: FC<Props> = ({
)}
</PortalToFollowElemTrigger>
<PortalToFollowElemContent>
<div className={cn('relative max-h-[642px] min-h-20 w-[361px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur pb-4 shadow-lg backdrop-blur-sm', !isShowSettingAuth && 'overflow-y-auto pb-2')}>
{!isShowSettingAuth && (
<>
<div className='system-xl-semibold px-4 pb-1 pt-3.5 text-text-primary'>{t(`plugin.detailPanel.toolSelector.${isEdit ? 'toolSetting' : 'title'}`)}</div>
{/* base form */}
<div className='flex flex-col gap-3 px-4 py-2'>
<div className='flex flex-col gap-1'>
<div className='system-sm-semibold flex h-6 items-center text-text-secondary'>{t('plugin.detailPanel.toolSelector.toolLabel')}</div>
<ToolPicker
placement='bottom'
offset={offset}
trigger={
<ToolTrigger
open={panelShowState || isShowChooseTool}
value={value}
provider={currentProvider}
/>
}
isShow={panelShowState || isShowChooseTool}
onShowChange={trigger ? onPanelShowStateChange as any : setIsShowChooseTool}
disabled={false}
supportAddCustomTool
onSelect={handleSelectTool}
onSelectMultiple={handleSelectMultipleTool}
scope={scope}
selectedTools={selectedTools}
canChooseMCPTool={canChooseMCPTool}
/>
</div>
<div className='flex flex-col gap-1'>
<div className='system-sm-semibold flex h-6 items-center text-text-secondary'>{t('plugin.detailPanel.toolSelector.descriptionLabel')}</div>
<Textarea
className='resize-none'
placeholder={t('plugin.detailPanel.toolSelector.descriptionPlaceholder')}
value={value?.extra?.description || ''}
onChange={handleDescriptionChange}
disabled={!value?.provider_name}
/>
</div>
</div>
{/* authorization */}
{currentProvider && currentProvider.type === CollectionType.builtIn && currentProvider.allow_delete && (
<>
<Divider className='my-1 w-full' />
<div className='px-4 py-2'>
{!currentProvider.is_team_authorization && (
<Button
variant='primary'
className={cn('w-full shrink-0')}
onClick={() => setShowSettingAuth(true)}
disabled={!isCurrentWorkspaceManager}
>
{t('tools.auth.unauthorized')}
</Button>
)}
{currentProvider.is_team_authorization && (
<Button
variant='secondary'
className={cn('w-full shrink-0')}
onClick={() => setShowSettingAuth(true)}
disabled={!isCurrentWorkspaceManager}
>
<Indicator className='mr-2' color={'green'} />
{t('tools.auth.authorized')}
</Button>
)}
</div>
</>
)}
{/* tool settings */}
{(currentToolSettings.length > 0 || currentToolParams.length > 0) && currentProvider?.is_team_authorization && (
<>
<Divider className='my-1 w-full' />
{/* tabs */}
{nodeId && showTabSlider && (
<TabSlider
className='mt-1 shrink-0 px-4'
itemClassName='py-3'
noBorderBottom
smallItem
value={currType}
onChange={(value) => {
setCurrType(value)
}}
options={[
{ value: 'settings', text: t('plugin.detailPanel.toolSelector.settings')! },
{ value: 'params', text: t('plugin.detailPanel.toolSelector.params')! },
]}
<div className={cn('relative max-h-[642px] min-h-20 w-[361px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur pb-4 shadow-lg backdrop-blur-sm', 'overflow-y-auto pb-2')}>
<>
<div className='system-xl-semibold px-4 pb-1 pt-3.5 text-text-primary'>{t(`plugin.detailPanel.toolSelector.${isEdit ? 'toolSetting' : 'title'}`)}</div>
{/* base form */}
<div className='flex flex-col gap-3 px-4 py-2'>
<div className='flex flex-col gap-1'>
<div className='system-sm-semibold flex h-6 items-center text-text-secondary'>{t('plugin.detailPanel.toolSelector.toolLabel')}</div>
<ToolPicker
placement='bottom'
offset={offset}
trigger={
<ToolTrigger
open={panelShowState || isShowChooseTool}
value={value}
provider={currentProvider}
/>
)}
{nodeId && showTabSlider && currType === 'params' && (
<div className='px-4 py-2'>
}
isShow={panelShowState || isShowChooseTool}
onShowChange={trigger ? onPanelShowStateChange as any : setIsShowChooseTool}
disabled={false}
supportAddCustomTool
onSelect={handleSelectTool}
onSelectMultiple={handleSelectMultipleTool}
scope={scope}
selectedTools={selectedTools}
canChooseMCPTool={canChooseMCPTool}
/>
</div>
<div className='flex flex-col gap-1'>
<div className='system-sm-semibold flex h-6 items-center text-text-secondary'>{t('plugin.detailPanel.toolSelector.descriptionLabel')}</div>
<Textarea
className='resize-none'
placeholder={t('plugin.detailPanel.toolSelector.descriptionPlaceholder')}
value={value?.extra?.description || ''}
onChange={handleDescriptionChange}
disabled={!value?.provider_name}
/>
</div>
</div>
{/* authorization */}
{currentProvider && currentProvider.type === CollectionType.builtIn && currentProvider.allow_delete && (
<>
<Divider className='my-1 w-full' />
<div className='px-4 py-2'>
<PluginAuthInAgent
pluginPayload={{
provider: currentProvider.name,
category: AuthCategory.tool,
}}
credentialId={value?.credential_id}
onAuthorizationItemClick={handleAuthorizationItemClick}
/>
</div>
</>
)}
{/* tool settings */}
{(currentToolSettings.length > 0 || currentToolParams.length > 0) && currentProvider?.is_team_authorization && (
<>
<Divider className='my-1 w-full' />
{/* tabs */}
{nodeId && showTabSlider && (
<TabSlider
className='mt-1 shrink-0 px-4'
itemClassName='py-3'
noBorderBottom
smallItem
value={currType}
onChange={(value) => {
setCurrType(value)
}}
options={[
{ value: 'settings', text: t('plugin.detailPanel.toolSelector.settings')! },
{ value: 'params', text: t('plugin.detailPanel.toolSelector.params')! },
]}
/>
)}
{nodeId && showTabSlider && currType === 'params' && (
<div className='px-4 py-2'>
<div className='system-xs-regular text-text-tertiary'>{t('plugin.detailPanel.toolSelector.paramsTip1')}</div>
<div className='system-xs-regular text-text-tertiary'>{t('plugin.detailPanel.toolSelector.paramsTip2')}</div>
</div>
)}
{/* user settings only */}
{userSettingsOnly && (
<div className='p-4 pb-1'>
<div className='system-sm-semibold-uppercase text-text-primary'>{t('plugin.detailPanel.toolSelector.settings')}</div>
</div>
)}
{/* reasoning config only */}
{nodeId && reasoningConfigOnly && (
<div className='mb-1 p-4 pb-1'>
<div className='system-sm-semibold-uppercase text-text-primary'>{t('plugin.detailPanel.toolSelector.params')}</div>
<div className='pb-1'>
<div className='system-xs-regular text-text-tertiary'>{t('plugin.detailPanel.toolSelector.paramsTip1')}</div>
<div className='system-xs-regular text-text-tertiary'>{t('plugin.detailPanel.toolSelector.paramsTip2')}</div>
</div>
)}
{/* user settings only */}
{userSettingsOnly && (
<div className='p-4 pb-1'>
<div className='system-sm-semibold-uppercase text-text-primary'>{t('plugin.detailPanel.toolSelector.settings')}</div>
</div>
)}
{/* reasoning config only */}
{nodeId && reasoningConfigOnly && (
<div className='mb-1 p-4 pb-1'>
<div className='system-sm-semibold-uppercase text-text-primary'>{t('plugin.detailPanel.toolSelector.params')}</div>
<div className='pb-1'>
<div className='system-xs-regular text-text-tertiary'>{t('plugin.detailPanel.toolSelector.paramsTip1')}</div>
<div className='system-xs-regular text-text-tertiary'>{t('plugin.detailPanel.toolSelector.paramsTip2')}</div>
</div>
</div>
)}
{/* user settings form */}
{(currType === 'settings' || userSettingsOnly) && (
<div className='px-4 py-2'>
<ToolForm
inPanel
readOnly={false}
nodeId={nodeId}
schema={settingsFormSchemas as any}
value={getPlainValue(value?.settings || {})}
onChange={handleSettingsFormChange}
/>
</div>
)}
{/* reasoning config form */}
{nodeId && (currType === 'params' || reasoningConfigOnly) && (
<ReasoningConfigForm
value={value?.parameters || {}}
onChange={handleParamsFormChange}
schemas={paramsFormSchemas as any}
nodeOutputVars={nodeOutputVars}
availableNodes={availableNodes}
</div>
)}
{/* user settings form */}
{(currType === 'settings' || userSettingsOnly) && (
<div className='px-4 py-2'>
<ToolForm
inPanel
readOnly={false}
nodeId={nodeId}
schema={settingsFormSchemas as any}
value={getPlainValue(value?.settings || {})}
onChange={handleSettingsFormChange}
/>
)}
</>
)}
</>
)}
{/* authorization panel */}
{isShowSettingAuth && currentProvider && (
<>
<div className='relative flex flex-col gap-1 pt-3.5'>
<div className='absolute -top-2 left-2 w-[345px] rounded-t-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur pt-2 backdrop-blur-sm'></div>
<div
className='system-xs-semibold-uppercase flex h-6 cursor-pointer items-center gap-1 px-3 text-text-accent-secondary'
onClick={() => setShowSettingAuth(false)}
>
<RiArrowLeftLine className='h-4 w-4' />
BACK
</div>
<div className='system-xl-semibold px-4 text-text-primary'>{t('tools.auth.setupModalTitle')}</div>
<div className='system-xs-regular px-4 text-text-tertiary'>{t('tools.auth.setupModalTitleDescription')}</div>
</div>
<ToolCredentialForm
collection={currentProvider}
onCancel={() => setShowSettingAuth(false)}
onSaved={async value => updatePermission({
providerName: currentProvider.name,
credentials: value,
})}
/>
</>
)}
</div>
)}
{/* reasoning config form */}
{nodeId && (currType === 'params' || reasoningConfigOnly) && (
<ReasoningConfigForm
value={value?.parameters || {}}
onChange={handleParamsFormChange}
schemas={paramsFormSchemas as any}
nodeOutputVars={nodeOutputVars}
availableNodes={availableNodes}
nodeId={nodeId}
/>
)}
</>
)}
</>
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>

View File

@@ -30,7 +30,6 @@ type Props = {
onSwitchChange?: (value: boolean) => void
onDelete?: () => void
noAuth?: boolean
onAuth?: () => void
isError?: boolean
errorTip?: any
uninstalled?: boolean
@@ -38,6 +37,7 @@ type Props = {
onInstall?: () => void
versionMismatch?: boolean
open: boolean
authRemoved?: boolean
canChooseMCPTool?: boolean,
}
@@ -53,13 +53,13 @@ const ToolItem = ({
onSwitchChange,
onDelete,
noAuth,
onAuth,
uninstalled,
installInfo,
onInstall,
isError,
errorTip,
versionMismatch,
authRemoved,
canChooseMCPTool,
}: Props) => {
const { t } = useTranslation()
@@ -125,11 +125,17 @@ const ToolItem = ({
<McpToolNotSupportTooltip />
)}
{!isError && !uninstalled && !versionMismatch && noAuth && (
<Button variant='secondary' size='small' onClick={onAuth}>
<Button variant='secondary' size='small'>
{t('tools.notAuthorized')}
<Indicator className='ml-2' color='orange' />
</Button>
)}
{!isError && !uninstalled && !versionMismatch && authRemoved && (
<Button variant='secondary' size='small'>
{t('plugin.auth.authRemoved')}
<Indicator className='ml-2' color='red' />
</Button>
)}
{!isError && !uninstalled && versionMismatch && installInfo && (
<div onClick={e => e.stopPropagation()}>
<SwitchPluginVersion

View File

@@ -33,6 +33,7 @@ export type ToolDefaultValue = {
params: Record<string, any>
paramSchemas: Record<string, any>[]
output_schema: Record<string, any>
credential_id?: string
meta?: PluginMeta
}
@@ -46,4 +47,5 @@ export type ToolValue = {
parameters?: Record<string, any>
enabled?: boolean
extra?: Record<string, any>
credential_id?: string
}

View File

@@ -59,6 +59,12 @@ import { useLogs } from '@/app/components/workflow/run/hooks'
import PanelWrap from '../before-run-form/panel-wrap'
import SpecialResultPanel from '@/app/components/workflow/run/special-result-panel'
import { Stop } from '@/app/components/base/icons/src/vender/line/mediaAndDevices'
import {
AuthorizedInNode,
PluginAuth,
} from '@/app/components/plugins/plugin-auth'
import { AuthCategory } from '@/app/components/plugins/plugin-auth'
import { canFindTool } from '@/utils'
type BasePanelProps = {
children: ReactNode
@@ -221,6 +227,22 @@ const BasePanel: FC<BasePanelProps> = ({
return {}
})()
const buildInTools = useStore(s => s.buildInTools)
const currCollection = useMemo(() => {
return buildInTools.find(item => canFindTool(item.id, data.provider_id))
}, [buildInTools, data.provider_id])
const showPluginAuth = useMemo(() => {
return data.type === BlockEnum.Tool && currCollection?.allow_delete
}, [currCollection, data.type])
const handleAuthorizationItemClick = useCallback((credential_id: string) => {
handleNodeDataUpdateWithSyncDraft({
id,
data: {
credential_id,
},
})
}, [handleNodeDataUpdateWithSyncDraft, id])
if(logParams.showSpecialResultPanel) {
return (
<div className={cn(
@@ -353,12 +375,42 @@ const BasePanel: FC<BasePanelProps> = ({
onChange={handleDescriptionChange}
/>
</div>
<div className='pl-4'>
<Tab
value={tabType}
onChange={setTabType}
/>
</div>
{
showPluginAuth && (
<PluginAuth
className='px-4 pb-2'
pluginPayload={{
provider: currCollection?.name || '',
category: AuthCategory.tool,
}}
>
<div className='flex items-center justify-between pl-4 pr-3'>
<Tab
value={tabType}
onChange={setTabType}
/>
<AuthorizedInNode
pluginPayload={{
provider: currCollection?.name || '',
category: AuthCategory.tool,
}}
onAuthorizationItemClick={handleAuthorizationItemClick}
credentialId={data.credential_id}
/>
</div>
</PluginAuth>
)
}
{
!showPluginAuth && (
<div className='flex items-center justify-between pl-4 pr-3'>
<Tab
value={tabType}
onChange={setTabType}
/>
</div>
)
}
<Split />
</div>

View File

@@ -5,10 +5,8 @@ import Split from '../_base/components/split'
import type { ToolNodeType } from './types'
import useConfig from './use-config'
import ToolForm from './components/tool-form'
import Button from '@/app/components/base/button'
import Field from '@/app/components/workflow/nodes/_base/components/field'
import type { NodePanelProps } from '@/app/components/workflow/types'
import ConfigCredential from '@/app/components/tools/setting/build-in/config-credentials'
import Loading from '@/app/components/base/loading'
import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars'
import StructureOutputItem from '@/app/components/workflow/nodes/_base/components/variable/object-child-tree-panel/show'
@@ -32,10 +30,6 @@ const Panel: FC<NodePanelProps<ToolNodeType>> = ({
setToolSettingValue,
currCollection,
isShowAuthBtn,
showSetAuth,
showSetAuthModal,
hideSetAuthModal,
handleSaveAuth,
isLoading,
outputSchema,
hasObjectOutput,
@@ -52,19 +46,6 @@ const Panel: FC<NodePanelProps<ToolNodeType>> = ({
return (
<div className='pt-2'>
{!readOnly && isShowAuthBtn && (
<>
<div className='px-4'>
<Button
variant='primary'
className='w-full'
onClick={showSetAuthModal}
>
{t(`${i18nPrefix}.authorize`)}
</Button>
</div>
</>
)}
{!isShowAuthBtn && (
<div className='relative'>
{toolInputVarSchema.length > 0 && (
@@ -109,15 +90,6 @@ const Panel: FC<NodePanelProps<ToolNodeType>> = ({
</div>
)}
{showSetAuth && (
<ConfigCredential
collection={currCollection!}
onCancel={hideSetAuthModal}
onSaved={handleSaveAuth}
isHideRemoveBtn
/>
)}
<div>
<OutputVars>
<>

View File

@@ -93,6 +93,7 @@ export type CommonNodeType<T = {}> = {
error_strategy?: ErrorHandleTypeEnum
retry_config?: WorkflowRetryConfig
default_value?: DefaultValueForm[]
credential_id?: string
} & T & Partial<Pick<ToolDefaultValue, 'provider_id' | 'provider_type' | 'provider_name' | 'tool_name'>>
export type CommonEdgeType = {

View File

@@ -214,6 +214,29 @@ const translation = {
requestAPlugin: 'Request a plugin',
publishPlugins: 'Publish plugins',
difyVersionNotCompatible: 'The current Dify version is not compatible with this plugin, please upgrade to the minimum version required: {{minimalDifyVersion}}',
auth: {
default: 'Default',
custom: 'Custom',
setDefault: 'Set as default',
useOAuth: 'Use OAuth',
useOAuthAuth: 'Use OAuth Authorization',
addOAuth: 'Add OAuth',
setupOAuth: 'Setup OAuth Client',
useApi: 'Use API Key',
addApi: 'Add API Key',
useApiAuth: 'API Key Authorization Configuration',
useApiAuthDesc: 'After configuring credentials, all members within the workspace can use this tool when orchestrating applications.',
oauthClientSettings: 'OAuth Client Settings',
saveOnly: 'Save only',
saveAndAuth: 'Save and Authorize',
authorization: 'Authorization',
authorizations: 'Authorizations',
authorizationName: 'Authorization Name',
workspaceDefault: 'Workspace Default',
authRemoved: 'Auth removed',
clientInfo: 'As no system client secrets found for this tool provider, setup it manually is required, for redirect_uri, please use',
oauthClient: 'OAuth Client',
},
}
export default translation

View File

@@ -214,6 +214,29 @@ const translation = {
requestAPlugin: '申请插件',
publishPlugins: '发布插件',
difyVersionNotCompatible: '当前 Dify 版本不兼容该插件,其最低版本要求为 {{minimalDifyVersion}}',
auth: {
default: '默认',
custom: '自定义',
setDefault: '设为默认',
useOAuth: '使用 OAuth',
useOAuthAuth: '使用 OAuth 授权',
addOAuth: '添加 OAuth',
setupOAuth: '设置 OAuth 客户端',
useApi: '使用 API Key',
addApi: '添加 API Key',
useApiAuth: 'API Key 授权配置',
useApiAuthDesc: '配置凭据后,工作区内的所有成员在编排应用时都可以使用此工具。',
oauthClientSettings: 'OAuth 客户端设置',
saveOnly: '仅保存',
saveAndAuth: '保存并授权',
authorization: '凭据',
authorizations: '凭据',
authorizationName: '凭据名称',
workspaceDefault: '工作区默认',
authRemoved: '凭据已移除',
clientInfo: '由于未找到此工具提供者的系统客户端密钥,因此需要手动设置,对于 redirect_uri请使用',
oauthClient: 'OAuth 客户端',
},
}
export default translation

View File

@@ -0,0 +1,161 @@
import {
useMutation,
useQuery,
} from '@tanstack/react-query'
import { del, get, post } from './base'
import { useInvalid } from './use-base'
import type {
Credential,
CredentialTypeEnum,
} from '@/app/components/plugins/plugin-auth/types'
import type { FormSchema } from '@/app/components/base/form/types'
const NAME_SPACE = 'plugins-auth'
export const useGetPluginCredentialInfo = (
url: string,
) => {
return useQuery({
enabled: !!url,
queryKey: [NAME_SPACE, 'credential-info', url],
queryFn: () => get<{
supported_credential_types: string[]
credentials: Credential[]
is_oauth_custom_client_enabled: boolean
}>(url),
staleTime: 0,
})
}
export const useInvalidPluginCredentialInfo = (
url: string,
) => {
return useInvalid([NAME_SPACE, 'credential-info', url])
}
export const useSetPluginDefaultCredential = (
url: string,
) => {
return useMutation({
mutationFn: (id: string) => {
return post(url, { body: { id } })
},
})
}
export const useGetPluginCredentialList = (
url: string,
) => {
return useQuery({
queryKey: [NAME_SPACE, 'credential-list', url],
queryFn: () => get(url),
})
}
export const useAddPluginCredential = (
url: string,
) => {
return useMutation({
mutationFn: (params: {
credentials: Record<string, any>
type: CredentialTypeEnum
name?: string
}) => {
return post(url, { body: params })
},
})
}
export const useUpdatePluginCredential = (
url: string,
) => {
return useMutation({
mutationFn: (params: {
credential_id: string
credentials?: Record<string, any>
name?: string
}) => {
return post(url, { body: params })
},
})
}
export const useDeletePluginCredential = (
url: string,
) => {
return useMutation({
mutationFn: (params: { credential_id: string }) => {
return post(url, { body: params })
},
})
}
export const useGetPluginCredentialSchema = (
url: string,
) => {
return useQuery({
queryKey: [NAME_SPACE, 'credential-schema', url],
queryFn: () => get<FormSchema[]>(url),
})
}
export const useGetPluginOAuthUrl = (
url: string,
) => {
return useMutation({
mutationKey: [NAME_SPACE, 'oauth-url', url],
mutationFn: () => {
return get<
{
authorization_url: string
state: string
context_id: string
}>(url)
},
})
}
export const useGetPluginOAuthClientSchema = (
url: string,
) => {
return useQuery({
queryKey: [NAME_SPACE, 'oauth-client-schema', url],
queryFn: () => get<{
schema: FormSchema[]
is_oauth_custom_client_enabled: boolean
is_system_oauth_params_exists?: boolean
client_params?: Record<string, any>
redirect_uri?: string
}>(url),
staleTime: 0,
})
}
export const useInvalidPluginOAuthClientSchema = (
url: string,
) => {
return useInvalid([NAME_SPACE, 'oauth-client-schema', url])
}
export const useSetPluginOAuthCustomClient = (
url: string,
) => {
return useMutation({
mutationFn: (params: {
client_params: Record<string, any>
enable_oauth_custom_client: boolean
}) => {
return post<{ result: string }>(url, { body: params })
},
})
}
export const useDeletePluginOAuthCustomClient = (
url: string,
) => {
return useMutation({
mutationFn: () => {
return del<{ result: string }>(url)
},
})
}

View File

@@ -16,10 +16,11 @@ import {
const NAME_SPACE = 'tools'
const useAllToolProvidersKey = [NAME_SPACE, 'allToolProviders']
export const useAllToolProviders = () => {
export const useAllToolProviders = (enabled = true) => {
return useQuery<Collection[]>({
queryKey: useAllToolProvidersKey,
queryFn: () => get<Collection[]>('/workspaces/current/tool-providers'),
enabled,
})
}

View File

@@ -130,6 +130,7 @@ export type AgentTool = {
enabled: boolean
isDeleted?: boolean
notAuthor?: boolean
credential_id?: string
}
export type ToolItem = {