diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index cc735ae67..b933560a5 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -6,6 +6,7 @@ on: - "main" - "deploy/dev" - "deploy/enterprise" + - "build/**" tags: - "*" diff --git a/api/.env.example b/api/.env.example index 3fe95c44b..b8976e5b1 100644 --- a/api/.env.example +++ b/api/.env.example @@ -5,17 +5,17 @@ SECRET_KEY= # Console API base URL -CONSOLE_API_URL=http://127.0.0.1:5001 -CONSOLE_WEB_URL=http://127.0.0.1:3000 +CONSOLE_API_URL=http://localhost:5001 +CONSOLE_WEB_URL=http://localhost:3000 # Service API base URL -SERVICE_API_URL=http://127.0.0.1:5001 +SERVICE_API_URL=http://localhost:5001 # Web APP base URL -APP_WEB_URL=http://127.0.0.1:3000 +APP_WEB_URL=http://localhost:3000 # Files URL -FILES_URL=http://127.0.0.1:5001 +FILES_URL=http://localhost:5001 # INTERNAL_FILES_URL is used for plugin daemon communication within Docker network. # Set this to the internal Docker service URL for proper plugin file access. @@ -138,8 +138,8 @@ SUPABASE_API_KEY=your-access-key SUPABASE_URL=your-server-url # CORS configuration -WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* -CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* +WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,* +CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,* # Vector database configuration # support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone diff --git a/api/commands.py b/api/commands.py index 86769847c..9f933a378 100644 --- a/api/commands.py +++ b/api/commands.py @@ -2,19 +2,22 @@ import base64 import json import logging import secrets -from typing import Optional +from typing import Any, Optional import click from flask import current_app +from pydantic import TypeAdapter from sqlalchemy import select from werkzeug.exceptions import NotFound from configs import dify_config from constants.languages import languages +from core.plugin.entities.plugin import ToolProviderID from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.models.document import Document +from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params from events.app_event import app_was_created from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -27,6 +30,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, D from models.dataset import Document as DatasetDocument from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation from models.provider import Provider, ProviderModel +from models.tools import ToolOAuthSystemClient from services.account_service import AccountService, RegisterService, TenantService from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs from services.plugin.data_migration import PluginDataMigration @@ -1155,3 +1159,49 @@ def remove_orphaned_files_on_storage(force: bool): click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green")) else: click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow")) + + +@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.") +@click.option("--provider", prompt=True, help="Provider name") +@click.option("--client-params", prompt=True, help="Client Params") +def setup_system_tool_oauth_client(provider, client_params): + """ + Setup system tool oauth client + """ + provider_id = ToolProviderID(provider) + provider_name = provider_id.provider_name + plugin_id = provider_id.plugin_id + + try: + # json validate + click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) + client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) + click.echo(click.style("Client params validated successfully.", fg="green")) + + click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) + click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) + oauth_client_params = encrypt_system_oauth_params(client_params_dict) + click.echo(click.style("Client params encrypted successfully.", fg="green")) + except Exception as e: + click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) + return + + deleted_count = ( + db.session.query(ToolOAuthSystemClient) + .filter_by( + provider=provider_name, + plugin_id=plugin_id, + ) + .delete() + ) + if deleted_count > 0: + click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) + + oauth_client = ToolOAuthSystemClient( + provider=provider_name, + plugin_id=plugin_id, + encrypted_oauth_params=oauth_client_params, + ) + db.session.add(oauth_client) + db.session.commit() + click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green")) diff --git a/api/constants/__init__.py b/api/constants/__init__.py index a84de0a45..9e052320a 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -1,6 +1,7 @@ from configs import dify_config HIDDEN_VALUE = "[__HIDDEN__]" +UNKNOWN_VALUE = "[__UNKNOWN__]" UUID_NIL = "00000000-0000-0000-0000-000000000000" DEFAULT_FILE_NUMBER_LIMITS = 3 diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index df50871a3..e41375e52 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,23 +1,32 @@ import io from urllib.parse import urlparse -from flask import redirect, send_file +from flask import make_response, redirect, request, send_file from flask_login import current_user -from flask_restful import Resource, reqparse -from sqlalchemy.orm import Session +from flask_restful import ( + Resource, + reqparse, +) from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api -from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + enterprise_license_required, + setup_required, +) from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.auth.auth_provider import OAuthClientProvider from core.mcp.error import MCPAuthError, MCPError from core.mcp.mcp_client import MCPClient from core.model_runtime.utils.encoders import jsonable_encoder -from extensions.ext_database import db -from libs.helper import alphanumeric, uuid_value +from core.plugin.entities.plugin import ToolProviderID +from core.plugin.impl.oauth import OAuthHandler +from core.tools.entities.tool_entities import CredentialType +from libs.helper import StrLen, alphanumeric, uuid_value from libs.login import login_required +from services.plugin.oauth_service import OAuthProxyService from services.tools.api_tools_manage_service import ApiToolManageService from services.tools.builtin_tools_manage_service import BuiltinToolManageService from services.tools.mcp_tools_mange_service import MCPToolManageService @@ -89,7 +98,7 @@ class ToolBuiltinProviderInfoApi(Resource): user_id = user.id tenant_id = user.current_tenant_id - return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider)) + return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) class ToolBuiltinProviderDeleteApi(Resource): @@ -98,17 +107,47 @@ class ToolBuiltinProviderDeleteApi(Resource): @account_initialization_required def post(self, provider): user = current_user - if not user.is_admin_or_owner: raise Forbidden() + tenant_id = user.current_tenant_id + req = reqparse.RequestParser() + req.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + args = req.parse_args() + + return BuiltinToolManageService.delete_builtin_tool_provider( + tenant_id, + provider, + args["credential_id"], + ) + + +class ToolBuiltinProviderAddApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + user = current_user + user_id = user.id tenant_id = user.current_tenant_id - return BuiltinToolManageService.delete_builtin_tool_provider( - user_id, - tenant_id, - provider, + parser = reqparse.RequestParser() + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json") + parser.add_argument("type", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + if args["type"] not in CredentialType.values(): + raise ValueError(f"Invalid credential type: {args['type']}") + + return BuiltinToolManageService.add_builtin_tool_provider( + user_id=user_id, + tenant_id=tenant_id, + provider=provider, + credentials=args["credentials"], + name=args["name"], + api_type=CredentialType.of(args["type"]), ) @@ -126,19 +165,20 @@ class ToolBuiltinProviderUpdateApi(Resource): tenant_id = user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") + parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") args = parser.parse_args() - with Session(db.engine) as session: - result = BuiltinToolManageService.update_builtin_tool_provider( - session=session, - user_id=user_id, - tenant_id=tenant_id, - provider_name=provider, - credentials=args["credentials"], - ) - session.commit() + result = BuiltinToolManageService.update_builtin_tool_provider( + user_id=user_id, + tenant_id=tenant_id, + provider=provider, + credential_id=args["credential_id"], + credentials=args.get("credentials", None), + name=args.get("name", ""), + ) return result @@ -149,9 +189,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): def get(self, provider): tenant_id = current_user.current_tenant_id - return BuiltinToolManageService.get_builtin_tool_provider_credentials( - tenant_id=tenant_id, - provider_name=provider, + return jsonable_encoder( + BuiltinToolManageService.get_builtin_tool_provider_credentials( + tenant_id=tenant_id, + provider_name=provider, + ) ) @@ -344,12 +386,15 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider): + def get(self, provider, credential_type): user = current_user - tenant_id = user.current_tenant_id - return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id) + return jsonable_encoder( + BuiltinToolManageService.list_builtin_provider_credentials_schema( + provider, CredentialType.of(credential_type), tenant_id + ) + ) class ToolApiProviderSchemaApi(Resource): @@ -586,15 +631,12 @@ class ToolApiListApi(Resource): @account_initialization_required def get(self): user = current_user - - user_id = user.id tenant_id = user.current_tenant_id return jsonable_encoder( [ provider.to_dict() for provider in ApiToolManageService.list_api_tools( - user_id, tenant_id, ) ] @@ -631,6 +673,179 @@ class ToolLabelsApi(Resource): return jsonable_encoder(ToolLabelsService.list_tool_labels()) +class ToolPluginOAuthApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + tool_provider = ToolProviderID(provider) + plugin_id = tool_provider.plugin_id + provider_name = tool_provider.provider_name + + # todo check permission + user = current_user + + if not user.is_admin_or_owner: + raise Forbidden() + + tenant_id = user.current_tenant_id + oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider) + if oauth_client_params is None: + raise Forbidden("no oauth available client config found for this tool provider") + + oauth_handler = OAuthHandler() + context_id = OAuthProxyService.create_proxy_context( + user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name + ) + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" + authorization_url_response = oauth_handler.get_authorization_url( + tenant_id=tenant_id, + user_id=user.id, + plugin_id=plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, + system_credentials=oauth_client_params, + ) + response = make_response(jsonable_encoder(authorization_url_response)) + response.set_cookie( + "context_id", + context_id, + httponly=True, + samesite="Lax", + max_age=OAuthProxyService.__MAX_AGE__, + ) + return response + + +class ToolOAuthCallback(Resource): + @setup_required + def get(self, provider): + context_id = request.cookies.get("context_id") + if not context_id: + raise Forbidden("context_id not found") + + context = OAuthProxyService.use_proxy_context(context_id) + if context is None: + raise Forbidden("Invalid context_id") + + tool_provider = ToolProviderID(provider) + plugin_id = tool_provider.plugin_id + provider_name = tool_provider.provider_name + user_id, tenant_id = context.get("user_id"), context.get("tenant_id") + + oauth_handler = OAuthHandler() + oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider) + if oauth_client_params is None: + raise Forbidden("no oauth available client config found for this tool provider") + + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" + credentials = oauth_handler.get_credentials( + tenant_id=tenant_id, + user_id=user_id, + plugin_id=plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, + system_credentials=oauth_client_params, + request=request, + ).credentials + + if not credentials: + raise Exception("the plugin credentials failed") + + # add credentials to database + BuiltinToolManageService.add_builtin_tool_provider( + user_id=user_id, + tenant_id=tenant_id, + provider=provider, + credentials=dict(credentials), + api_type=CredentialType.OAUTH2, + ) + return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") + + +class ToolBuiltinProviderSetDefaultApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + parser = reqparse.RequestParser() + parser.add_argument("id", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + return BuiltinToolManageService.set_default_provider( + tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] + ) + + +class ToolOAuthCustomClient(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + parser = reqparse.RequestParser() + parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") + parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") + args = parser.parse_args() + + user = current_user + + if not user.is_admin_or_owner: + raise Forbidden() + + return BuiltinToolManageService.save_custom_oauth_client_params( + tenant_id=user.current_tenant_id, + provider=provider, + client_params=args.get("client_params", {}), + enable_oauth_custom_client=args.get("enable_oauth_custom_client", True), + ) + + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + return jsonable_encoder( + BuiltinToolManageService.get_custom_oauth_client_params( + tenant_id=current_user.current_tenant_id, provider=provider + ) + ) + + @setup_required + @login_required + @account_initialization_required + def delete(self, provider): + return jsonable_encoder( + BuiltinToolManageService.delete_custom_oauth_client_params( + tenant_id=current_user.current_tenant_id, provider=provider + ) + ) + + +class ToolBuiltinProviderGetOauthClientSchemaApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + return jsonable_encoder( + BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema( + tenant_id=current_user.current_tenant_id, provider_name=provider + ) + ) + + +class ToolBuiltinProviderGetCredentialInfoApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + tenant_id = current_user.current_tenant_id + + return jsonable_encoder( + BuiltinToolManageService.get_builtin_tool_provider_credential_info( + tenant_id=tenant_id, + provider=provider, + ) + ) + + class ToolProviderMCPApi(Resource): @setup_required @login_required @@ -794,17 +1009,33 @@ class ToolMCPCallbackApi(Resource): # tool provider api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") +# tool oauth +api.add_resource(ToolPluginOAuthApi, "/oauth/plugin//tool/authorization-url") +api.add_resource(ToolOAuthCallback, "/oauth/plugin//tool/callback") +api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin//oauth/custom-client") + # builtin tool provider api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin//tools") api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin//info") +api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin//add") api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin//delete") api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin//update") +api.add_resource( + ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin//default-credential" +) +api.add_resource( + ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin//credential/info" +) api.add_resource( ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials" ) api.add_resource( ToolBuiltinProviderCredentialsSchemaApi, - "/workspaces/current/tool-provider/builtin//credentials_schema", + "/workspaces/current/tool-provider/builtin//credential/schema/", +) +api.add_resource( + ToolBuiltinProviderGetOauthClientSchemaApi, + "/workspaces/current/tool-provider/builtin//oauth/client-schema", ) api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin//icon") diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 327e9ce83..5dfe41eb6 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -175,6 +175,7 @@ class PluginInvokeToolApi(Resource): provider=payload.provider, tool_name=payload.tool, tool_parameters=payload.tool_parameters, + credential_id=payload.credential_id, ), ) diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 143a3a51a..a31c1050b 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -16,6 +16,7 @@ class AgentToolEntity(BaseModel): tool_name: str tool_parameters: dict[str, Any] = Field(default_factory=dict) plugin_unique_identifier: str | None = None + credential_id: str | None = None class AgentPromptEntity(BaseModel): diff --git a/api/core/agent/strategy/base.py b/api/core/agent/strategy/base.py index ead81a7a0..a52a1dfd7 100644 --- a/api/core/agent/strategy/base.py +++ b/api/core/agent/strategy/base.py @@ -4,6 +4,7 @@ from typing import Any, Optional from core.agent.entities import AgentInvokeMessage from core.agent.plugin_entities import AgentStrategyParameter +from core.plugin.entities.request import InvokeCredentials class BaseAgentStrategy(ABC): @@ -18,11 +19,12 @@ class BaseAgentStrategy(ABC): conversation_id: Optional[str] = None, app_id: Optional[str] = None, message_id: Optional[str] = None, + credentials: Optional[InvokeCredentials] = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent strategy. """ - yield from self._invoke(params, user_id, conversation_id, app_id, message_id) + yield from self._invoke(params, user_id, conversation_id, app_id, message_id, credentials) def get_parameters(self) -> Sequence[AgentStrategyParameter]: """ @@ -38,5 +40,6 @@ class BaseAgentStrategy(ABC): conversation_id: Optional[str] = None, app_id: Optional[str] = None, message_id: Optional[str] = None, + credentials: Optional[InvokeCredentials] = None, ) -> Generator[AgentInvokeMessage, None, None]: pass diff --git a/api/core/agent/strategy/plugin.py b/api/core/agent/strategy/plugin.py index 4cfcfbf86..04661581a 100644 --- a/api/core/agent/strategy/plugin.py +++ b/api/core/agent/strategy/plugin.py @@ -4,6 +4,7 @@ from typing import Any, Optional from core.agent.entities import AgentInvokeMessage from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter from core.agent.strategy.base import BaseAgentStrategy +from core.plugin.entities.request import InvokeCredentials, PluginInvokeContext from core.plugin.impl.agent import PluginAgentClient from core.plugin.utils.converter import convert_parameters_to_plugin_format @@ -40,6 +41,7 @@ class PluginAgentStrategy(BaseAgentStrategy): conversation_id: Optional[str] = None, app_id: Optional[str] = None, message_id: Optional[str] = None, + credentials: Optional[InvokeCredentials] = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent strategy. @@ -58,4 +60,5 @@ class PluginAgentStrategy(BaseAgentStrategy): conversation_id=conversation_id, app_id=app_id, message_id=message_id, + context=PluginInvokeContext(credentials=credentials or InvokeCredentials()), ) diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index 590b944c0..8887d2500 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -39,6 +39,7 @@ class AgentConfigManager: "provider_id": tool["provider_id"], "tool_name": tool["tool_name"], "tool_parameters": tool.get("tool_parameters", {}), + "credential_id": tool.get("credential_id", None), } agent_tools.append(AgentToolEntity(**agent_tool_properties)) diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py new file mode 100644 index 000000000..48ec3be5c --- /dev/null +++ b/api/core/helper/provider_cache.py @@ -0,0 +1,84 @@ +import json +from abc import ABC, abstractmethod +from json import JSONDecodeError +from typing import Any, Optional + +from extensions.ext_redis import redis_client + + +class ProviderCredentialsCache(ABC): + """Base class for provider credentials cache""" + + def __init__(self, **kwargs): + self.cache_key = self._generate_cache_key(**kwargs) + + @abstractmethod + def _generate_cache_key(self, **kwargs) -> str: + """Generate cache key based on subclass implementation""" + pass + + def get(self) -> Optional[dict]: + """Get cached provider credentials""" + cached_credentials = redis_client.get(self.cache_key) + if cached_credentials: + try: + cached_credentials = cached_credentials.decode("utf-8") + return dict(json.loads(cached_credentials)) + except JSONDecodeError: + return None + return None + + def set(self, config: dict[str, Any]) -> None: + """Cache provider credentials""" + redis_client.setex(self.cache_key, 86400, json.dumps(config)) + + def delete(self) -> None: + """Delete cached provider credentials""" + redis_client.delete(self.cache_key) + + +class SingletonProviderCredentialsCache(ProviderCredentialsCache): + """Cache for tool single provider credentials""" + + def __init__(self, tenant_id: str, provider_type: str, provider_identity: str): + super().__init__( + tenant_id=tenant_id, + provider_type=provider_type, + provider_identity=provider_identity, + ) + + def _generate_cache_key(self, **kwargs) -> str: + tenant_id = kwargs["tenant_id"] + provider_type = kwargs["provider_type"] + identity_name = kwargs["provider_identity"] + identity_id = f"{provider_type}.{identity_name}" + return f"{provider_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}" + + +class ToolProviderCredentialsCache(ProviderCredentialsCache): + """Cache for tool provider credentials""" + + def __init__(self, tenant_id: str, provider: str, credential_id: str): + super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id) + + def _generate_cache_key(self, **kwargs) -> str: + tenant_id = kwargs["tenant_id"] + provider = kwargs["provider"] + credential_id = kwargs["credential_id"] + return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}" + + +class NoOpProviderCredentialCache: + """No-op provider credential cache""" + + def get(self) -> Optional[dict]: + """Get cached provider credentials""" + return None + + def set(self, config: dict[str, Any]) -> None: + """Cache provider credentials""" + pass + + def delete(self) -> None: + """Delete cached provider credentials""" + pass diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py deleted file mode 100644 index 2e4a04c57..000000000 --- a/api/core/helper/tool_provider_cache.py +++ /dev/null @@ -1,51 +0,0 @@ -import json -from enum import Enum -from json import JSONDecodeError -from typing import Optional - -from extensions.ext_redis import redis_client - - -class ToolProviderCredentialsCacheType(Enum): - PROVIDER = "tool_provider" - ENDPOINT = "endpoint" - - -class ToolProviderCredentialsCache: - def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType): - self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" - - def get(self) -> Optional[dict]: - """ - Get cached model provider credentials. - - :return: - """ - cached_provider_credentials = redis_client.get(self.cache_key) - if cached_provider_credentials: - try: - cached_provider_credentials = cached_provider_credentials.decode("utf-8") - cached_provider_credentials = json.loads(cached_provider_credentials) - except JSONDecodeError: - return None - - return dict(cached_provider_credentials) - else: - return None - - def set(self, credentials: dict) -> None: - """ - Cache model provider credentials. - - :param credentials: provider credentials - :return: - """ - redis_client.setex(self.cache_key, 86400, json.dumps(credentials)) - - def delete(self) -> None: - """ - Delete cached model provider credentials. - - :return: - """ - redis_client.delete(self.cache_key) diff --git a/api/core/plugin/backwards_invocation/encrypt.py b/api/core/plugin/backwards_invocation/encrypt.py index 81a5d033a..213f5c726 100644 --- a/api/core/plugin/backwards_invocation/encrypt.py +++ b/api/core/plugin/backwards_invocation/encrypt.py @@ -1,16 +1,20 @@ +from core.helper.provider_cache import SingletonProviderCredentialsCache from core.plugin.entities.request import RequestInvokeEncrypt -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import create_provider_encrypter from models.account import Tenant class PluginEncrypter: @classmethod def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict: - encrypter = ProviderConfigEncrypter( + encrypter, cache = create_provider_encrypter( tenant_id=tenant.id, config=payload.config, - provider_type=payload.namespace, - provider_identity=payload.identity, + cache=SingletonProviderCredentialsCache( + tenant_id=tenant.id, + provider_type=payload.namespace, + provider_identity=payload.identity, + ), ) if payload.opt == "encrypt": @@ -22,7 +26,7 @@ class PluginEncrypter: "data": encrypter.decrypt(payload.data), } elif payload.opt == "clear": - encrypter.delete_tool_credentials_cache() + cache.delete() return { "data": {}, } diff --git a/api/core/plugin/backwards_invocation/tool.py b/api/core/plugin/backwards_invocation/tool.py index 1d62743f1..06773504d 100644 --- a/api/core/plugin/backwards_invocation/tool.py +++ b/api/core/plugin/backwards_invocation/tool.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any +from typing import Any, Optional from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.plugin.backwards_invocation.base import BaseBackwardsInvocation @@ -23,6 +23,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): provider: str, tool_name: str, tool_parameters: dict[str, Any], + credential_id: Optional[str] = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke tool @@ -30,7 +31,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): # get tool runtime try: tool_runtime = ToolManager.get_tool_runtime_from_plugin( - tool_type, tenant_id, provider, tool_name, tool_parameters + tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id ) response = ToolEngine.generic_invoke( tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1 diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 89f595ec4..3a783dad3 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -27,6 +27,20 @@ from core.workflow.nodes.question_classifier.entities import ( ) +class InvokeCredentials(BaseModel): + tool_credentials: dict[str, str] = Field( + default_factory=dict, + description="Map of tool provider to credential id, used to store the credential id for the tool provider.", + ) + + +class PluginInvokeContext(BaseModel): + credentials: Optional[InvokeCredentials] = Field( + default_factory=InvokeCredentials, + description="Credentials context for the plugin invocation or backward invocation.", + ) + + class RequestInvokeTool(BaseModel): """ Request to invoke a tool @@ -36,6 +50,7 @@ class RequestInvokeTool(BaseModel): provider: str tool: str tool_parameters: dict + credential_id: Optional[str] = None class BaseRequestInvokeModel(BaseModel): diff --git a/api/core/plugin/impl/agent.py b/api/core/plugin/impl/agent.py index 66b77c748..9575c57ac 100644 --- a/api/core/plugin/impl/agent.py +++ b/api/core/plugin/impl/agent.py @@ -6,6 +6,7 @@ from core.plugin.entities.plugin import GenericProviderID from core.plugin.entities.plugin_daemon import ( PluginAgentProviderEntity, ) +from core.plugin.entities.request import PluginInvokeContext from core.plugin.impl.base import BasePluginClient @@ -83,6 +84,7 @@ class PluginAgentClient(BasePluginClient): conversation_id: Optional[str] = None, app_id: Optional[str] = None, message_id: Optional[str] = None, + context: Optional[PluginInvokeContext] = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent with the given tenant, user, plugin, provider, name and parameters. @@ -99,6 +101,7 @@ class PluginAgentClient(BasePluginClient): "conversation_id": conversation_id, "app_id": app_id, "message_id": message_id, + "context": context.model_dump() if context else {}, "data": { "agent_strategy_provider": agent_provider_id.provider_name, "agent_strategy": agent_strategy, diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py index b006bf1d4..d73e5d9f9 100644 --- a/api/core/plugin/impl/oauth.py +++ b/api/core/plugin/impl/oauth.py @@ -15,27 +15,32 @@ class OAuthHandler(BasePluginClient): user_id: str, plugin_id: str, provider: str, + redirect_uri: str, system_credentials: Mapping[str, Any], ) -> PluginOAuthAuthorizationUrlResponse: - response = self._request_with_plugin_daemon_response_stream( - "POST", - f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", - PluginOAuthAuthorizationUrlResponse, - data={ - "user_id": user_id, - "data": { - "provider": provider, - "system_credentials": system_credentials, + try: + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", + PluginOAuthAuthorizationUrlResponse, + data={ + "user_id": user_id, + "data": { + "provider": provider, + "redirect_uri": redirect_uri, + "system_credentials": system_credentials, + }, }, - }, - headers={ - "X-Plugin-ID": plugin_id, - "Content-Type": "application/json", - }, - ) - for resp in response: - return resp - raise ValueError("No response received from plugin daemon for authorization URL request.") + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", + }, + ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") + except Exception as e: + raise ValueError(f"Error getting authorization URL: {e}") def get_credentials( self, @@ -43,6 +48,7 @@ class OAuthHandler(BasePluginClient): user_id: str, plugin_id: str, provider: str, + redirect_uri: str, system_credentials: Mapping[str, Any], request: Request, ) -> PluginOAuthCredentialsResponse: @@ -50,30 +56,33 @@ class OAuthHandler(BasePluginClient): Get credentials from the given request. """ - # encode request to raw http request - raw_request_bytes = self._convert_request_to_raw_data(request) - - response = self._request_with_plugin_daemon_response_stream( - "POST", - f"plugin/{tenant_id}/dispatch/oauth/get_credentials", - PluginOAuthCredentialsResponse, - data={ - "user_id": user_id, - "data": { - "provider": provider, - "system_credentials": system_credentials, - # for json serialization - "raw_http_request": binascii.hexlify(raw_request_bytes).decode(), + try: + # encode request to raw http request + raw_request_bytes = self._convert_request_to_raw_data(request) + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/oauth/get_credentials", + PluginOAuthCredentialsResponse, + data={ + "user_id": user_id, + "data": { + "provider": provider, + "redirect_uri": redirect_uri, + "system_credentials": system_credentials, + # for json serialization + "raw_http_request": binascii.hexlify(raw_request_bytes).decode(), + }, }, - }, - headers={ - "X-Plugin-ID": plugin_id, - "Content-Type": "application/json", - }, - ) - for resp in response: - return resp - raise ValueError("No response received from plugin daemon for authorization URL request.") + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", + }, + ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") + except Exception as e: + raise ValueError(f"Error getting credentials: {e}") def _convert_request_to_raw_data(self, request: Request) -> bytes: """ diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index 19b26c8fe..04225f95e 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity from core.plugin.impl.base import BasePluginClient -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter class PluginToolManager(BasePluginClient): @@ -78,6 +78,7 @@ class PluginToolManager(BasePluginClient): tool_provider: str, tool_name: str, credentials: dict[str, Any], + credential_type: CredentialType, tool_parameters: dict[str, Any], conversation_id: Optional[str] = None, app_id: Optional[str] = None, @@ -102,6 +103,7 @@ class PluginToolManager(BasePluginClient): "provider": tool_provider_id.provider_name, "tool": tool_name, "credentials": credentials, + "credential_type": credential_type, "tool_parameters": tool_parameters, }, }, diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index c9e157cb7..ddec7b132 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -4,7 +4,7 @@ from openai import BaseModel from pydantic import Field from core.app.entities.app_invoke_entities import InvokeFrom -from core.tools.entities.tool_entities import ToolInvokeFrom +from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom class ToolRuntime(BaseModel): @@ -17,6 +17,7 @@ class ToolRuntime(BaseModel): invoke_from: Optional[InvokeFrom] = None tool_invoke_from: Optional[ToolInvokeFrom] = None credentials: dict[str, Any] = Field(default_factory=dict) + credential_type: CredentialType = Field(default=CredentialType.API_KEY) runtime_parameters: dict[str, Any] = Field(default_factory=dict) diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index cf75bd3d7..a70ded9ef 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool -from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType +from core.tools.entities.tool_entities import ( + CredentialType, + OAuthSchema, + ToolEntity, + ToolProviderEntity, + ToolProviderType, +) from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.errors import ( ToolProviderNotFoundError, @@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController): credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {}) credentials_schema.append(credential_dict) + oauth_schema = None + if provider_yaml.get("oauth_schema", None) is not None: + oauth_schema = OAuthSchema( + client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []), + credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []), + ) + super().__init__( entity=ToolProviderEntity( identity=provider_yaml["identity"], credentials_schema=credentials_schema, + oauth_schema=oauth_schema, ), ) @@ -97,10 +111,39 @@ class BuiltinToolProviderController(ToolProviderController): :return: the credentials schema """ - if not self.entity.credentials_schema: - return [] + return self.get_credentials_schema_by_type(CredentialType.API_KEY.value) - return self.entity.credentials_schema.copy() + def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]: + """ + returns the credentials schema of the provider + + :param credential_type: the type of the credential + :return: the credentials schema of the provider + """ + if credential_type == CredentialType.OAUTH2.value: + return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else [] + if credential_type == CredentialType.API_KEY.value: + return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] + raise ValueError(f"Invalid credential type: {credential_type}") + + def get_oauth_client_schema(self) -> list[ProviderConfig]: + """ + returns the oauth client schema of the provider + + :return: the oauth client schema + """ + return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else [] + + def get_supported_credential_types(self) -> list[str]: + """ + returns the credential support type of the provider + """ + types = [] + if self.entity.credentials_schema is not None and len(self.entity.credentials_schema) > 0: + types.append(CredentialType.API_KEY.value) + if self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) > 0: + types.append(CredentialType.OAUTH2.value) + return types def get_tools(self) -> list[BuiltinTool]: """ @@ -123,7 +166,11 @@ class BuiltinToolProviderController(ToolProviderController): :return: whether the provider needs credentials """ - return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0 + return ( + self.entity.credentials_schema is not None + and len(self.entity.credentials_schema) != 0 + or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0) + ) @property def provider_type(self) -> ToolProviderType: diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 90134ba71..27ce96b90 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, field_validator from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType +from core.tools.entities.tool_entities import CredentialType, ToolProviderType class ToolApiEntity(BaseModel): @@ -87,3 +87,22 @@ class ToolProviderApiEntity(BaseModel): def optional_field(self, key: str, value: Any) -> dict: """Return dict with key-value if value is truthy, empty dict otherwise.""" return {key: value} if value else {} + + +class ToolProviderCredentialApiEntity(BaseModel): + id: str = Field(description="The unique id of the credential") + name: str = Field(description="The name of the credential") + provider: str = Field(description="The provider of the credential") + credential_type: CredentialType = Field(description="The type of the credential") + is_default: bool = Field( + default=False, description="Whether the credential is the default credential for the provider in the workspace" + ) + credentials: dict = Field(description="The credentials of the provider") + + +class ToolProviderCredentialInfoApiEntity(BaseModel): + supported_credential_types: list[str] = Field(description="The supported credential types of the provider") + is_oauth_custom_client_enabled: bool = Field( + default=False, description="Whether the OAuth custom client is enabled for the provider" + ) + credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider") diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 64568a8ed..5377cbbb6 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -370,10 +370,18 @@ class ToolEntity(BaseModel): return v or [] +class OAuthSchema(BaseModel): + client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client") + credentials_schema: list[ProviderConfig] = Field( + default_factory=list, description="The schema of the OAuth credentials" + ) + + class ToolProviderEntity(BaseModel): identity: ToolProviderIdentity plugin_id: Optional[str] = None credentials_schema: list[ProviderConfig] = Field(default_factory=list) + oauth_schema: Optional[OAuthSchema] = None class ToolProviderEntityWithPlugin(ToolProviderEntity): @@ -453,6 +461,7 @@ class ToolSelector(BaseModel): options: Optional[list[PluginParameterOption]] = None provider_id: str = Field(..., description="The id of the provider") + credential_id: Optional[str] = Field(default=None, description="The id of the credential") tool_name: str = Field(..., description="The name of the tool") tool_description: str = Field(..., description="The description of the tool") tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form") @@ -460,3 +469,36 @@ class ToolSelector(BaseModel): def to_plugin_parameter(self) -> dict[str, Any]: return self.model_dump() + + +class CredentialType(enum.StrEnum): + API_KEY = "api-key" + OAUTH2 = "oauth2" + + def get_name(self): + if self == CredentialType.API_KEY: + return "API KEY" + elif self == CredentialType.OAUTH2: + return "AUTH" + else: + return self.value.replace("-", " ").upper() + + def is_editable(self): + return self == CredentialType.API_KEY + + def is_validate_allowed(self): + return self == CredentialType.API_KEY + + @classmethod + def values(cls): + return [item.value for item in cls] + + @classmethod + def of(cls, credential_type: str) -> "CredentialType": + type_name = credential_type.lower() + if type_name == "api-key": + return cls.API_KEY + elif type_name == "oauth2": + return cls.OAUTH2 + else: + raise ValueError(f"Invalid credential type: {credential_type}") diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index d21e3d7d1..aef2677c3 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -44,6 +44,7 @@ class PluginTool(Tool): tool_provider=self.entity.identity.provider, tool_name=self.entity.identity.name, credentials=self.runtime.credentials, + credential_type=self.runtime.credential_type, tool_parameters=tool_parameters, conversation_id=conversation_id, app_id=app_id, diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 22a9853b4..d61856a8f 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast from yarl import URL import contexts +from core.helper.provider_cache import ToolProviderCredentialsCache from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.tool import PluginToolManager from core.tools.__base.tool_provider import ToolProviderController @@ -17,6 +18,7 @@ from core.tools.mcp_tool.provider import MCPToolProviderController from core.tools.mcp_tool.tool import MCPTool from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.tool import PluginTool +from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.workflow.entities.variable_pool import VariablePool from services.tools.mcp_tools_mange_service import MCPToolManageService @@ -24,7 +26,6 @@ from services.tools.mcp_tools_mange_service import MCPToolManageService if TYPE_CHECKING: from core.workflow.nodes.tool.entities import ToolEntity - from configs import dify_config from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -41,16 +42,17 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( ApiProviderAuthType, + CredentialType, ToolInvokeFrom, ToolParameter, ToolProviderType, ) -from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError +from core.tools.errors import ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ( - ProviderConfigEncrypter, ToolParameterConfigurationManager, ) +from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider @@ -68,8 +70,11 @@ class ToolManager: @classmethod def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController: """ + get the hardcoded provider + """ + if len(cls._hardcoded_providers) == 0: # init the builtin providers cls.load_hardcoded_providers_cache() @@ -113,7 +118,12 @@ class ToolManager: contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(Lock()) + plugin_tool_providers = contexts.plugin_tool_providers.get() + if provider in plugin_tool_providers: + return plugin_tool_providers[provider] + with contexts.plugin_tool_providers_lock.get(): + # double check plugin_tool_providers = contexts.plugin_tool_providers.get() if provider in plugin_tool_providers: return plugin_tool_providers[provider] @@ -131,25 +141,7 @@ class ToolManager: ) plugin_tool_providers[provider] = controller - - return controller - - @classmethod - def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None: - """ - get the builtin tool - - :param provider: the name of the provider - :param tool_name: the name of the tool - :param tenant_id: the id of the tenant - :return: the provider, the tool - """ - provider_controller = cls.get_builtin_provider(provider, tenant_id) - tool = provider_controller.get_tool(tool_name) - if tool is None: - raise ToolNotFoundError(f"tool {tool_name} not found") - - return tool + return controller @classmethod def get_tool_runtime( @@ -160,6 +152,7 @@ class ToolManager: tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, + credential_id: Optional[str] = None, ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]: """ get the tool runtime @@ -170,6 +163,7 @@ class ToolManager: :param tenant_id: the tenant id :param invoke_from: invoke from :param tool_invoke_from: the tool invoke from + :param credential_id: the credential id :return: the tool """ @@ -193,49 +187,70 @@ class ToolManager: ) ), ) - + builtin_provider = None if isinstance(provider_controller, PluginToolProviderController): provider_id_entity = ToolProviderID(provider_id) - # get credentials - builtin_provider: BuiltinToolProvider | None = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == str(provider_id_entity)) - | (BuiltinToolProvider.provider == provider_id_entity.provider_name), - ) - .first() - ) + # get specific credentials + if is_valid_uuid(credential_id): + try: + builtin_provider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + .first() + ) + except Exception as e: + builtin_provider = None + logger.info(f"Error getting builtin provider {credential_id}:{e}", exc_info=True) + # if the provider has been deleted, raise an error + if builtin_provider is None: + raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}") + # fallback to the default provider if builtin_provider is None: - raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + # use the default provider + builtin_provider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == str(provider_id_entity)) + | (BuiltinToolProvider.provider == provider_id_entity.provider_name), + ) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + .first() + ) + if builtin_provider is None: + raise ToolProviderNotFoundError(f"no default provider for {provider_id}") else: builtin_provider = ( db.session.query(BuiltinToolProvider) .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) .first() ) if builtin_provider is None: raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") - # decrypt the credentials - credentials = builtin_provider.credentials - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type) + ], + cache=ToolProviderCredentialsCache( + tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id + ), ) - - decrypted_credentials = tool_configuration.decrypt(credentials) - return cast( BuiltinTool, builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, - credentials=decrypted_credentials, + credentials=encrypter.decrypt(builtin_provider.credentials), + credential_type=CredentialType.of(builtin_provider.credential_type), runtime_parameters={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -245,22 +260,16 @@ class ToolManager: elif provider_type == ToolProviderType.API: api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) - - # decrypt the credentials - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()], - provider_type=api_provider.provider_type.value, - provider_identity=api_provider.entity.identity.name, + controller=api_provider, ) - decrypted_credentials = tool_configuration.decrypt(credentials) - return cast( ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, - credentials=decrypted_credentials, + credentials=encrypter.decrypt(credentials), invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, ) @@ -320,6 +329,7 @@ class ToolManager: tenant_id=tenant_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.AGENT, + credential_id=agent_tool.credential_id, ) runtime_parameters = {} parameters = tool_entity.get_merged_runtime_parameters() @@ -362,6 +372,7 @@ class ToolManager: tenant_id=tenant_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, + credential_id=workflow_tool.credential_id, ) parameters = tool_runtime.get_merged_runtime_parameters() @@ -391,6 +402,7 @@ class ToolManager: provider: str, tool_name: str, tool_parameters: dict[str, Any], + credential_id: Optional[str] = None, ) -> Tool: """ get tool runtime from plugin @@ -402,6 +414,7 @@ class ToolManager: tenant_id=tenant_id, invoke_from=InvokeFrom.SERVICE_API, tool_invoke_from=ToolInvokeFrom.PLUGIN, + credential_id=credential_id, ) runtime_parameters = {} parameters = tool_entity.get_merged_runtime_parameters() @@ -551,6 +564,22 @@ class ToolManager: return cls._builtin_tools_labels[tool_name] + @classmethod + def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]: + """ + list all the builtin providers + """ + # according to multi credentials, select the one with is_default=True first, then created_at oldest + # for compatibility with old version + sql = """ + SELECT DISTINCT ON (tenant_id, provider) id + FROM tool_builtin_providers + WHERE tenant_id = :tenant_id + ORDER BY tenant_id, provider, is_default DESC, created_at DESC + """ + ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()] + return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all() + @classmethod def list_providers_from_api( cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral @@ -565,21 +594,13 @@ class ToolManager: with db.session.no_autoflush: if "builtin" in filters: - # get builtin providers builtin_providers = cls.list_builtin_providers(tenant_id) - # get db builtin providers - db_builtin_providers: list[BuiltinToolProvider] = ( - db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() - ) - - # rewrite db_builtin_providers - for db_provider in db_builtin_providers: - tool_provider_id = str(ToolProviderID(db_provider.provider)) - db_provider.provider = tool_provider_id - - def find_db_builtin_provider(provider): - return next((x for x in db_builtin_providers if x.provider == provider), None) + # key: provider name, value: provider + db_builtin_providers = { + str(ToolProviderID(provider.provider)): provider + for provider in cls.list_default_builtin_providers(tenant_id) + } # append builtin providers for provider in builtin_providers: @@ -591,10 +612,9 @@ class ToolManager: name_func=lambda x: x.identity.name, ): continue - user_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider, - db_provider=find_db_builtin_provider(provider.entity.identity.name), + db_provider=db_builtin_providers.get(provider.entity.identity.name), decrypt_credentials=False, ) @@ -604,7 +624,6 @@ class ToolManager: result_providers[f"builtin_provider.{user_provider.name}"] = user_provider # get db api providers - if "api" in filters: db_api_providers: list[ApiToolProvider] = ( db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() @@ -764,15 +783,12 @@ class ToolManager: auth_type, ) # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], - provider_type=controller.provider_type.value, - provider_identity=controller.entity.identity.name, + controller=controller, ) - decrypted_credentials = tool_configuration.decrypt(credentials) - masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) + masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials)) try: icon = json.loads(provider_obj.icon) diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 251fedf56..aceba6e69 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -1,12 +1,8 @@ from copy import deepcopy from typing import Any -from pydantic import BaseModel - -from core.entities.provider_entities import BasicProviderConfig from core.helper import encrypter from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType -from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ( ToolParameter, @@ -14,110 +10,6 @@ from core.tools.entities.tool_entities import ( ) -class ProviderConfigEncrypter(BaseModel): - tenant_id: str - config: list[BasicProviderConfig] - provider_type: str - provider_identity: str - - def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: - """ - deep copy data - """ - return deepcopy(data) - - def encrypt(self, data: dict[str, str]) -> dict[str, str]: - """ - encrypt tool credentials with tenant id - - return a deep copy of credentials with encrypted values - """ - data = self._deep_copy(data) - - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") - data[field_name] = encrypted - - return data - - def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: - """ - mask tool credentials - - return a deep copy of credentials with masked values - """ - data = self._deep_copy(data) - - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - if len(data[field_name]) > 6: - data[field_name] = ( - data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] - ) - else: - data[field_name] = "*" * len(data[field_name]) - - return data - - def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, str]: - """ - decrypt tool credentials with tenant id - - return a deep copy of credentials with decrypted values - """ - if use_cache: - cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f"{self.provider_type}.{self.provider_identity}", - cache_type=ToolProviderCredentialsCacheType.PROVIDER, - ) - cached_credentials = cache.get() - if cached_credentials: - return cached_credentials - data = self._deep_copy(data) - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - try: - # if the value is None or empty string, skip decrypt - if not data[field_name]: - continue - - data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) - except Exception: - pass - - if use_cache: - cache.set(data) - return data - - def delete_tool_credentials_cache(self): - cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f"{self.provider_type}.{self.provider_identity}", - cache_type=ToolProviderCredentialsCacheType.PROVIDER, - ) - cache.delete() - - class ToolParameterConfigurationManager: """ Tool parameter configuration manager diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py new file mode 100644 index 000000000..5fdfd3b9d --- /dev/null +++ b/api/core/tools/utils/encryption.py @@ -0,0 +1,142 @@ +from copy import deepcopy +from typing import Any, Optional, Protocol + +from core.entities.provider_entities import BasicProviderConfig +from core.helper import encrypter +from core.helper.provider_cache import SingletonProviderCredentialsCache +from core.tools.__base.tool_provider import ToolProviderController + + +class ProviderConfigCache(Protocol): + """ + Interface for provider configuration cache operations + """ + + def get(self) -> Optional[dict]: + """Get cached provider configuration""" + ... + + def set(self, config: dict[str, Any]) -> None: + """Cache provider configuration""" + ... + + def delete(self) -> None: + """Delete cached provider configuration""" + ... + + +class ProviderConfigEncrypter: + tenant_id: str + config: list[BasicProviderConfig] + provider_config_cache: ProviderConfigCache + + def __init__( + self, + tenant_id: str, + config: list[BasicProviderConfig], + provider_config_cache: ProviderConfigCache, + ): + self.tenant_id = tenant_id + self.config = config + self.provider_config_cache = provider_config_cache + + def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: + """ + deep copy data + """ + return deepcopy(data) + + def encrypt(self, data: dict[str, str]) -> dict[str, str]: + """ + encrypt tool credentials with tenant id + + return a deep copy of credentials with encrypted values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") + data[field_name] = encrypted + + return data + + def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: + """ + mask tool credentials + + return a deep copy of credentials with masked values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + if len(data[field_name]) > 6: + data[field_name] = ( + data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] + ) + else: + data[field_name] = "*" * len(data[field_name]) + + return data + + def decrypt(self, data: dict[str, str]) -> dict[str, Any]: + """ + decrypt tool credentials with tenant id + + return a deep copy of credentials with decrypted values + """ + cached_credentials = self.provider_config_cache.get() + if cached_credentials: + return cached_credentials + + data = self._deep_copy(data) + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + try: + # if the value is None or empty string, skip decrypt + if not data[field_name]: + continue + + data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) + except Exception: + pass + + self.provider_config_cache.set(data) + return data + + +def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache): + return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache + + +def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController): + cache = SingletonProviderCredentialsCache( + tenant_id=tenant_id, + provider_type=controller.provider_type.value, + provider_identity=controller.entity.identity.name, + ) + encrypt = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], + provider_config_cache=cache, + ) + return encrypt, cache diff --git a/api/core/tools/utils/system_oauth_encryption.py b/api/core/tools/utils/system_oauth_encryption.py new file mode 100644 index 000000000..f3c946b95 --- /dev/null +++ b/api/core/tools/utils/system_oauth_encryption.py @@ -0,0 +1,187 @@ +import base64 +import hashlib +import logging +from collections.abc import Mapping +from typing import Any, Optional + +from Crypto.Cipher import AES +from Crypto.Random import get_random_bytes +from Crypto.Util.Padding import pad, unpad +from pydantic import TypeAdapter + +from configs import dify_config + +logger = logging.getLogger(__name__) + + +class OAuthEncryptionError(Exception): + """OAuth encryption/decryption specific error""" + + pass + + +class SystemOAuthEncrypter: + """ + A simple OAuth parameters encrypter using AES-CBC encryption. + + This class provides methods to encrypt and decrypt OAuth parameters + using AES-CBC mode with a key derived from the application's SECRET_KEY. + """ + + def __init__(self, secret_key: Optional[str] = None): + """ + Initialize the OAuth encrypter. + + Args: + secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY + + Raises: + ValueError: If SECRET_KEY is not configured or empty + """ + secret_key = secret_key or dify_config.SECRET_KEY or "" + + # Generate a fixed 256-bit key using SHA-256 + self.key = hashlib.sha256(secret_key.encode()).digest() + + def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str: + """ + Encrypt OAuth parameters. + + Args: + oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"} + + Returns: + Base64-encoded encrypted string + + Raises: + OAuthEncryptionError: If encryption fails + ValueError: If oauth_params is invalid + """ + + try: + # Generate random IV (16 bytes) + iv = get_random_bytes(16) + + # Create AES cipher (CBC mode) + cipher = AES.new(self.key, AES.MODE_CBC, iv) + + # Encrypt data + padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size) + encrypted_data = cipher.encrypt(padded_data) + + # Combine IV and encrypted data + combined = iv + encrypted_data + + # Return base64 encoded string + return base64.b64encode(combined).decode() + + except Exception as e: + raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e + + def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]: + """ + Decrypt OAuth parameters. + + Args: + encrypted_data: Base64-encoded encrypted string + + Returns: + Decrypted OAuth parameters dictionary + + Raises: + OAuthEncryptionError: If decryption fails + ValueError: If encrypted_data is invalid + """ + if not isinstance(encrypted_data, str): + raise ValueError("encrypted_data must be a string") + + if not encrypted_data: + raise ValueError("encrypted_data cannot be empty") + + try: + # Base64 decode + combined = base64.b64decode(encrypted_data) + + # Check minimum length (IV + at least one AES block) + if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data + raise ValueError("Invalid encrypted data format") + + # Separate IV and encrypted data + iv = combined[:16] + encrypted_data_bytes = combined[16:] + + # Create AES cipher + cipher = AES.new(self.key, AES.MODE_CBC, iv) + + # Decrypt data + decrypted_data = cipher.decrypt(encrypted_data_bytes) + unpadded_data = unpad(decrypted_data, AES.block_size) + + # Parse JSON + oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data) + + if not isinstance(oauth_params, dict): + raise ValueError("Decrypted data is not a valid dictionary") + + return oauth_params + + except Exception as e: + raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e + + +# Factory function for creating encrypter instances +def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter: + """ + Create an OAuth encrypter instance. + + Args: + secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY + + Returns: + SystemOAuthEncrypter instance + """ + return SystemOAuthEncrypter(secret_key=secret_key) + + +# Global encrypter instance (for backward compatibility) +_oauth_encrypter: Optional[SystemOAuthEncrypter] = None + + +def get_system_oauth_encrypter() -> SystemOAuthEncrypter: + """ + Get the global OAuth encrypter instance. + + Returns: + SystemOAuthEncrypter instance + """ + global _oauth_encrypter + if _oauth_encrypter is None: + _oauth_encrypter = SystemOAuthEncrypter() + return _oauth_encrypter + + +# Convenience functions for backward compatibility +def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str: + """ + Encrypt OAuth parameters using the global encrypter. + + Args: + oauth_params: OAuth parameters dictionary + + Returns: + Base64-encoded encrypted string + """ + return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params) + + +def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]: + """ + Decrypt OAuth parameters using the global encrypter. + + Args: + encrypted_data: Base64-encoded encrypted string + + Returns: + Decrypted OAuth parameters dictionary + """ + return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data) diff --git a/api/core/tools/utils/uuid_utils.py b/api/core/tools/utils/uuid_utils.py index 3046c08c8..bdcc33259 100644 --- a/api/core/tools/utils/uuid_utils.py +++ b/api/core/tools/utils/uuid_utils.py @@ -1,7 +1,9 @@ import uuid -def is_valid_uuid(uuid_str: str) -> bool: +def is_valid_uuid(uuid_str: str | None) -> bool: + if uuid_str is None or len(uuid_str) == 0: + return False try: uuid.UUID(uuid_str) return True diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 678b99d54..ce67197a5 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -4,6 +4,7 @@ from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional, cast from packaging.version import Version +from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import Session @@ -13,10 +14,16 @@ from core.agent.strategy.plugin import PluginAgentStrategy from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import AIModelEntity, ModelType +from core.plugin.entities.request import InvokeCredentials from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.plugin import PluginInstaller from core.provider_manager import ProviderManager -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType +from core.tools.entities.tool_entities import ( + ToolIdentity, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, +) from core.tools.tool_manager import ToolManager from core.variables.segments import StringSegment from core.workflow.entities.node_entities import NodeRunResult @@ -84,6 +91,7 @@ class AgentNode(ToolNode): for_log=True, strategy=strategy, ) + credentials = self._generate_credentials(parameters=parameters) # get conversation id conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) @@ -94,6 +102,7 @@ class AgentNode(ToolNode): user_id=self.user_id, app_id=self.app_id, conversation_id=conversation_id.text if conversation_id else None, + credentials=credentials, ) except Exception as e: yield RunCompletedEvent( @@ -246,6 +255,7 @@ class AgentNode(ToolNode): tool_name=tool.get("tool_name", ""), tool_parameters=parameters, plugin_unique_identifier=tool.get("plugin_unique_identifier", None), + credential_id=tool.get("credential_id", None), ) extra = tool.get("extra", {}) @@ -276,6 +286,7 @@ class AgentNode(ToolNode): { **tool_runtime.entity.model_dump(mode="json"), "runtime_parameters": runtime_parameters, + "credential_id": tool.get("credential_id", None), "provider_type": provider_type.value, } ) @@ -305,6 +316,27 @@ class AgentNode(ToolNode): return result + def _generate_credentials( + self, + parameters: dict[str, Any], + ) -> InvokeCredentials: + """ + Generate credentials based on the given agent parameters. + """ + + credentials = InvokeCredentials() + + # generate credentials for tools selector + credentials.tool_credentials = {} + for tool in parameters.get("tools", []): + if tool.get("credential_id"): + try: + identity = ToolIdentity.model_validate(tool.get("identity", {})) + credentials.tool_credentials[identity.provider] = tool.get("credential_id", None) + except ValidationError: + continue + return credentials + @classmethod def _extract_variable_selector_to_variable_mapping( cls, diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 691f6e019..88c5160d1 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -14,6 +14,7 @@ class ToolEntity(BaseModel): tool_name: str tool_label: str # redundancy tool_configurations: dict[str, Any] + credential_id: str | None = None plugin_unique_identifier: str | None = None # redundancy @field_validator("tool_configurations", mode="before") diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 249bd1442..6c9fc0bf1 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -20,6 +20,7 @@ def handle(sender, **kwargs): provider_id=tool_entity.provider_id, tool_name=tool_entity.tool_name, tenant_id=app.tenant_id, + credential_id=tool_entity.credential_id, ) manager = ToolParameterConfigurationManager( tenant_id=app.tenant_id, diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index ddc2158a0..600e336c1 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -18,6 +18,7 @@ def init_app(app: DifyApp): reset_email, reset_encrypt_key_pair, reset_password, + setup_system_tool_oauth_client, upgrade_db, vdb_migrate, ) @@ -40,6 +41,7 @@ def init_app(app: DifyApp): clear_free_plan_tenant_expired_logs, clear_orphaned_file_records, remove_orphaned_files_on_storage, + setup_system_tool_oauth_client, ] for cmd in cmds_to_register: app.cli.add_command(cmd) diff --git a/api/migrations/versions/2025_05_15_1635-16081485540c_.py b/api/migrations/versions/2025_05_15_1635-16081485540c_.py new file mode 100644 index 000000000..f55730bfb --- /dev/null +++ b/api/migrations/versions/2025_05_15_1635-16081485540c_.py @@ -0,0 +1,41 @@ +"""empty message + +Revision ID: 16081485540c +Revises: d28f2004b072 +Create Date: 2025-05-15 16:35:39.113777 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '16081485540c' +down_revision = '2adcbe1f5dfb' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tenant_plugin_auto_upgrade_strategies', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('strategy_setting', sa.String(length=16), server_default='fix_only', nullable=False), + sa.Column('upgrade_time_of_day', sa.Integer(), nullable=False), + sa.Column('upgrade_mode', sa.String(length=16), server_default='exclude', nullable=False), + sa.Column('exclude_plugins', sa.ARRAY(sa.String(length=255)), nullable=False), + sa.Column('include_plugins', sa.ARRAY(sa.String(length=255)), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tenant_plugin_auto_upgrade_strategy_pkey'), + sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin_auto_upgrade_strategy') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('tenant_plugin_auto_upgrade_strategies') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_06_1424-4474872b0ee6_workflow_draft_varaibles_add_node_execution_id.py b/api/migrations/versions/2025_06_06_1424-4474872b0ee6_workflow_draft_varaibles_add_node_execution_id.py index d7a5d116c..47ac27511 100644 --- a/api/migrations/versions/2025_06_06_1424-4474872b0ee6_workflow_draft_varaibles_add_node_execution_id.py +++ b/api/migrations/versions/2025_06_06_1424-4474872b0ee6_workflow_draft_varaibles_add_node_execution_id.py @@ -12,7 +12,7 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. revision = '4474872b0ee6' -down_revision = '2adcbe1f5dfb' +down_revision = '16081485540c' branch_labels = None depends_on = None diff --git a/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py b/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py new file mode 100644 index 000000000..df4fbf0a0 --- /dev/null +++ b/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py @@ -0,0 +1,62 @@ +"""tool oauth + +Revision ID: 71f5020c6470 +Revises: 4474872b0ee6 +Create Date: 2025-06-24 17:05:43.118647 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '71f5020c6470' +down_revision = '1c9ba48be8e4' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tool_oauth_system_clients', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('plugin_id', sa.String(length=512), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('encrypted_oauth_params', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'), + sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx') + ) + op.create_table('tool_oauth_tenant_clients', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('plugin_id', sa.String(length=512), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('encrypted_oauth_params', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'), + sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client') + ) + + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'::character varying"), nullable=False)) + batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'::character varying"), nullable=False)) + batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') + batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name']) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') + batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider']) + batch_op.drop_column('credential_type') + batch_op.drop_column('is_default') + batch_op.drop_column('name') + + op.drop_table('tool_oauth_tenant_clients') + op.drop_table('tool_oauth_system_clients') + # ### end Alembic commands ### diff --git a/api/models/tools.py b/api/models/tools.py index 9d2c3baea..7c8b5853b 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -21,6 +21,43 @@ from .model import Account, App, Tenant from .types import StringUUID +# system level tool oauth client params (client_id, client_secret, etc.) +class ToolOAuthSystemClient(Base): + __tablename__ = "tool_oauth_system_clients" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"), + db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) + provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + # oauth params of the tool provider + encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) + + +# tenant level tool oauth client params (client_id, client_secret, etc.) +class ToolOAuthTenantClient(Base): + __tablename__ = "tool_oauth_tenant_clients" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"), + db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + # tenant id + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) + provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + # oauth params of the tool provider + encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) + + @property + def oauth_params(self) -> dict: + return cast(dict, json.loads(self.encrypted_oauth_params or "{}")) + + class BuiltinToolProvider(Base): """ This table stores the tool provider information for built-in tools for each tenant. @@ -29,12 +66,14 @@ class BuiltinToolProvider(Base): __tablename__ = "tool_builtin_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), - # one tenant can only have one tool provider with the same name - db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"), + db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"), ) # id of the tool provider id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + name: Mapped[str] = mapped_column( + db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying") + ) # id of the tenant tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # who created this tool provider @@ -49,6 +88,11 @@ class BuiltinToolProvider(Base): updated_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) + is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + # credential type, e.g., "api-key", "oauth2" + credential_type: Mapped[str] = mapped_column( + db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") + ) @property def credentials(self) -> dict: @@ -68,7 +112,7 @@ class ApiToolProvider(Base): id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the api provider - name = db.Column(db.String(255), nullable=False) + name = db.Column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying")) # icon icon = db.Column(db.String(255), nullable=False) # original schema @@ -281,18 +325,19 @@ class MCPToolProvider(Base): @property def decrypted_credentials(self) -> dict: + from core.helper.provider_cache import NoOpProviderCredentialCache from core.tools.mcp_tool.provider import MCPToolProviderController - from core.tools.utils.configuration import ProviderConfigEncrypter + from core.tools.utils.encryption import create_provider_encrypter provider_controller = MCPToolProviderController._from_db(self) - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_provider_encrypter( tenant_id=self.tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.provider_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + cache=NoOpProviderCredentialCache(), ) - return tool_configuration.decrypt(self.credentials, use_cache=False) + + return encrypter.decrypt(self.credentials) # type: ignore class ToolModelInvoke(Base): diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 20257fa34..08e13c588 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -575,13 +575,26 @@ class AppDslService: raise ValueError("Missing draft workflow configuration, please check.") workflow_dict = workflow.to_dict(include_secret=include_secret) + # TODO: refactor: we need a better way to filter workspace related data from nodes for node in workflow_dict.get("graph", {}).get("nodes", []): - if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: - dataset_ids = node["data"].get("dataset_ids", []) - node["data"]["dataset_ids"] = [ + node_data = node.get("data", {}) + if not node_data: + continue + data_type = node_data.get("type", "") + if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value: + dataset_ids = node_data.get("dataset_ids", []) + node_data["dataset_ids"] = [ cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id) for dataset_id in dataset_ids ] + # filter credential id from tool node + if not include_secret and data_type == NodeType.TOOL.value: + node_data.pop("credential_id", None) + # filter credential id from agent node + if not include_secret and data_type == NodeType.AGENT.value: + for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []): + tool.pop("credential_id", None) + export_data["workflow"] = workflow_dict dependencies = cls._extract_dependencies_from_workflow(workflow) export_data["dependencies"] = [ @@ -602,7 +615,15 @@ class AppDslService: if not app_model_config: raise ValueError("Missing app configuration, please check.") - export_data["model_config"] = app_model_config.to_dict() + model_config = app_model_config.to_dict() + + # TODO: refactor: we need a better way to filter workspace related data from model config + # filter credential id from model config + for tool in model_config.get("agent_mode", {}).get("tools", []): + tool.pop("credential_id", None) + + export_data["model_config"] = model_config + dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict()) export_data["dependencies"] = [ jsonable_encoder(d.model_dump()) diff --git a/api/services/plugin/plugin_parameter_service.py b/api/services/plugin/plugin_parameter_service.py index 393213c0e..a1c5639e0 100644 --- a/api/services/plugin/plugin_parameter_service.py +++ b/api/services/plugin/plugin_parameter_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from core.plugin.entities.parameters import PluginParameterOption from core.plugin.impl.dynamic_select import DynamicSelectClient from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import create_tool_provider_encrypter from extensions.ext_database import db from models.tools import BuiltinToolProvider @@ -38,11 +38,9 @@ class PluginParameterService: case "tool": provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) # check if credentials are required @@ -63,7 +61,7 @@ class PluginParameterService: if db_record is None: raise ValueError(f"Builtin provider {provider} not found when fetching credentials") - credentials = tool_configuration.decrypt(db_record.credentials) + credentials = encrypter.decrypt(db_record.credentials) case _: raise ValueError(f"Invalid provider type: {provider_type}") diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index 0f22afd8d..0a5bc44b6 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -196,6 +196,17 @@ class PluginService: manager = PluginInstaller() return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier) + @staticmethod + def is_plugin_verified(tenant_id: str, plugin_unique_identifier: str) -> bool: + """ + Check if the plugin is verified + """ + manager = PluginInstaller() + try: + return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier).verified + except Exception: + return False + @staticmethod def fetch_install_tasks(tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]: """ diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 6f848d49c..80badf233 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db from models.tools import ApiToolProvider @@ -164,15 +164,11 @@ class ApiToolManageService: provider_controller.load_bundled_tools(tool_bundles) # encrypt credentials - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) - - encrypted_credentials = tool_configuration.encrypt(credentials) - db_provider.credentials_str = json.dumps(encrypted_credentials) + db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials)) db.session.add(db_provider) db.session.commit() @@ -297,28 +293,26 @@ class ApiToolManageService: provider_controller.load_bundled_tools(tool_bundles) # get original credentials if exists - tool_configuration = ProviderConfigEncrypter( + encrypter, cache = create_tool_provider_encrypter( tenant_id=tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) - original_credentials = tool_configuration.decrypt(provider.credentials) - masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) + original_credentials = encrypter.decrypt(provider.credentials) + masked_credentials = encrypter.mask_tool_credentials(original_credentials) # check if the credential has changed, save the original credential for name, value in credentials.items(): if name in masked_credentials and value == masked_credentials[name]: credentials[name] = original_credentials[name] - credentials = tool_configuration.encrypt(credentials) + credentials = encrypter.encrypt(credentials) provider.credentials_str = json.dumps(credentials) db.session.add(provider) db.session.commit() # delete cache - tool_configuration.delete_tool_credentials_cache() + cache.delete() # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) @@ -416,15 +410,13 @@ class ApiToolManageService: # decrypt credentials if db_provider.id: - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) - decrypted_credentials = tool_configuration.decrypt(credentials) + decrypted_credentials = encrypter.decrypt(credentials) # check if the credential has changed, save the original credential - masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) + masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials) for name, value in credentials.items(): if name in masked_credentials and value == masked_credentials[name]: credentials[name] = decrypted_credentials[name] @@ -446,7 +438,7 @@ class ApiToolManageService: return {"result": result or "empty response"} @staticmethod - def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]: + def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]: """ list api tools """ @@ -474,7 +466,7 @@ class ApiToolManageService: for tool in tools or []: user_provider.tools.append( ToolTransformService.convert_tool_entity_to_api_entity( - tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels + tenant_id=tenant_id, tool=tool, labels=labels ) ) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 58a4b2f17..430575b53 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -1,28 +1,84 @@ import json import logging +import re +from collections.abc import Mapping from pathlib import Path +from typing import Any, Optional from sqlalchemy.orm import Session from configs import dify_config +from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.helper.position_helper import is_filtered -from core.model_runtime.utils.encoders import jsonable_encoder +from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache from core.plugin.entities.plugin import ToolProviderID -from core.plugin.impl.exc import PluginDaemonClientSideError +from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort -from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity -from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError +from core.tools.entities.api_entities import ( + ToolApiEntity, + ToolProviderApiEntity, + ToolProviderCredentialApiEntity, + ToolProviderCredentialInfoApiEntity, +) +from core.tools.entities.tool_entities import CredentialType +from core.tools.errors import ToolProviderNotFoundError +from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import create_provider_encrypter +from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params from extensions.ext_database import db -from models.tools import BuiltinToolProvider +from extensions.ext_redis import redis_client +from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient +from services.plugin.plugin_service import PluginService from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) class BuiltinToolManageService: + __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100 + + @staticmethod + def delete_custom_oauth_client_params(tenant_id: str, provider: str): + """ + delete custom oauth client params + """ + tool_provider = ToolProviderID(provider) + with Session(db.engine) as session: + session.query(ToolOAuthTenantClient).filter_by( + tenant_id=tenant_id, + provider=tool_provider.provider_name, + plugin_id=tool_provider.plugin_id, + ).delete() + session.commit() + return {"result": "success"} + + @staticmethod + def get_builtin_tool_provider_oauth_client_schema(tenant_id: str, provider_name: str): + """ + get builtin tool provider oauth client schema + """ + provider = ToolManager.get_builtin_provider(provider_name, tenant_id) + verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified( + tenant_id, provider.plugin_unique_identifier + ) + + is_oauth_custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled( + tenant_id, provider_name + ) + is_system_oauth_params_exists = verified and BuiltinToolManageService.is_oauth_system_client_exists( + provider_name + ) + result = { + "schema": provider.get_oauth_client_schema(), + "is_oauth_custom_client_enabled": is_oauth_custom_client_enabled, + "is_system_oauth_params_exists": is_system_oauth_params_exists, + "client_params": BuiltinToolManageService.get_custom_oauth_client_params(tenant_id, provider_name), + "redirect_uri": f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback", + } + return result + @staticmethod def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]: """ @@ -36,27 +92,11 @@ class BuiltinToolManageService: provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) tools = provider_controller.get_tools() - tool_provider_configurations = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) - # check if user has added the provider - builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) - - credentials = {} - if builtin_provider is not None: - # get credentials - credentials = builtin_provider.credentials - credentials = tool_provider_configurations.decrypt(credentials) - result: list[ToolApiEntity] = [] for tool in tools or []: result.append( ToolTransformService.convert_tool_entity_to_api_entity( tool=tool, - credentials=credentials, tenant_id=tenant_id, labels=ToolLabelManager.get_tool_labels(provider_controller), ) @@ -65,25 +105,15 @@ class BuiltinToolManageService: return result @staticmethod - def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str): + def get_builtin_tool_provider_info(tenant_id: str, provider: str): """ get builtin tool provider info """ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) - tool_provider_configurations = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) # check if user has added the provider - builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) - - credentials = {} - if builtin_provider is not None: - # get credentials - credentials = builtin_provider.credentials - credentials = tool_provider_configurations.decrypt(credentials) + builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id) + if builtin_provider is None: + raise ValueError(f"you have not added provider {provider}") entity = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, @@ -92,128 +122,407 @@ class BuiltinToolManageService: ) entity.original_credentials = {} - return entity @staticmethod - def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str): + def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str): """ list builtin provider credentials schema + :param credential_type: credential type :param provider_name: the name of the provider :param tenant_id: the id of the tenant :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name, tenant_id) - return jsonable_encoder(provider.get_credentials_schema()) + return provider.get_credentials_schema_by_type(credential_type) @staticmethod def update_builtin_tool_provider( - session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict + user_id: str, + tenant_id: str, + provider: str, + credential_id: str, + credentials: dict | None = None, + name: str | None = None, ): """ update builtin tool provider """ - # get if the provider exists - provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) - - try: - # get provider - provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) - if not provider_controller.need_credentials: - raise ValueError(f"provider {provider_name} does not need credentials") - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + with Session(db.engine) as session: + # get if the provider exists + db_provider = ( + session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + .first() ) + if db_provider is None: + raise ValueError(f"you have not added provider {provider}") - # get original credentials if exists - if provider is not None: - original_credentials = tool_configuration.decrypt(provider.credentials) - masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) - # check if the credential has changed, save the original credential - for name, value in credentials.items(): - if name in masked_credentials and value == masked_credentials[name]: - credentials[name] = original_credentials[name] - # validate credentials - provider_controller.validate_credentials(user_id, credentials) - # encrypt credentials - credentials = tool_configuration.encrypt(credentials) - except ( - PluginDaemonClientSideError, - ToolProviderNotFoundError, - ToolNotFoundError, - ToolProviderCredentialValidationError, - ) as e: - raise ValueError(str(e)) + try: + if CredentialType.of(db_provider.credential_type).is_editable() and credentials: + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller.need_credentials: + raise ValueError(f"provider {provider} does not need credentials") - if provider is None: - # create provider - provider = BuiltinToolProvider( - tenant_id=tenant_id, - user_id=user_id, - provider=provider_name, - encrypted_credentials=json.dumps(credentials), - ) + encrypter, cache = BuiltinToolManageService.create_tool_encrypter( + tenant_id, db_provider, provider, provider_controller + ) - db.session.add(provider) - else: - provider.encrypted_credentials = json.dumps(credentials) + original_credentials = encrypter.decrypt(db_provider.credentials) + new_credentials: dict = { + key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE) + for key, value in credentials.items() + } - # delete cache - tool_configuration.delete_tool_credentials_cache() + if CredentialType.of(db_provider.credential_type).is_validate_allowed(): + provider_controller.validate_credentials(user_id, new_credentials) - db.session.commit() + # encrypt credentials + db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials)) + + cache.delete() + + # update name if provided + if name and name != db_provider.name: + # check if the name is already used + if ( + session.query(BuiltinToolProvider) + .filter_by(tenant_id=tenant_id, provider=provider, name=name) + .count() + > 0 + ): + raise ValueError(f"the credential name '{name}' is already used") + + db_provider.name = name + + session.commit() + except Exception as e: + session.rollback() + raise ValueError(str(e)) return {"result": "success"} @staticmethod - def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str): + def add_builtin_tool_provider( + user_id: str, + api_type: CredentialType, + tenant_id: str, + provider: str, + credentials: dict, + name: str | None = None, + ): + """ + add builtin tool provider + """ + try: + with Session(db.engine) as session: + lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" + with redis_client.lock(lock, timeout=20): + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller.need_credentials: + raise ValueError(f"provider {provider} does not need credentials") + + provider_count = ( + session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count() + ) + + # check if the provider count is reached the limit + if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__: + raise ValueError(f"you have reached the maximum number of providers for {provider}") + + # validate credentials if allowed + if CredentialType.of(api_type).is_validate_allowed(): + provider_controller.validate_credentials(user_id, credentials) + + # generate name if not provided + if name is None or name == "": + name = BuiltinToolManageService.generate_builtin_tool_provider_name( + session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type + ) + else: + # check if the name is already used + if ( + session.query(BuiltinToolProvider) + .filter_by(tenant_id=tenant_id, provider=provider, name=name) + .count() + > 0 + ): + raise ValueError(f"the credential name '{name}' is already used") + + # create encrypter + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(api_type) + ], + cache=NoOpProviderCredentialCache(), + ) + + db_provider = BuiltinToolProvider( + tenant_id=tenant_id, + user_id=user_id, + provider=provider, + encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), + credential_type=api_type.value, + name=name, + ) + + session.add(db_provider) + session.commit() + except Exception as e: + session.rollback() + raise ValueError(str(e)) + return {"result": "success"} + + @staticmethod + def create_tool_encrypter( + tenant_id: str, + db_provider: BuiltinToolProvider, + provider: str, + provider_controller: BuiltinToolProviderController, + ): + encrypter, cache = create_provider_encrypter( + tenant_id=tenant_id, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type) + ], + cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id), + ) + return encrypter, cache + + @staticmethod + def generate_builtin_tool_provider_name( + session: Session, tenant_id: str, provider: str, credential_type: CredentialType + ) -> str: + try: + db_providers = ( + session.query(BuiltinToolProvider) + .filter_by( + tenant_id=tenant_id, + provider=provider, + credential_type=credential_type.value, + ) + .order_by(BuiltinToolProvider.created_at.desc()) + .all() + ) + + # Get the default name pattern + default_pattern = f"{credential_type.get_name()}" + + # Find all names that match the default pattern: "{default_pattern} {number}" + pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$" + numbers = [] + + for db_provider in db_providers: + if db_provider.name: + match = re.match(pattern, db_provider.name.strip()) + if match: + numbers.append(int(match.group(1))) + + # If no default pattern names found, start with 1 + if not numbers: + return f"{default_pattern} 1" + + # Find the next number + max_number = max(numbers) + return f"{default_pattern} {max_number + 1}" + except Exception as e: + logger.warning(f"Error generating next provider name for {provider}: {str(e)}") + # fallback + return f"{credential_type.get_name()} 1" + + @staticmethod + def get_builtin_tool_provider_credentials( + tenant_id: str, provider_name: str + ) -> list[ToolProviderCredentialApiEntity]: """ get builtin tool provider credentials """ - provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) + with db.session.no_autoflush: + providers = ( + db.session.query(BuiltinToolProvider) + .filter_by(tenant_id=tenant_id, provider=provider_name) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + .all() + ) - if provider_obj is None: - return {} + if len(providers) == 0: + return [] - provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id) - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) - credentials = tool_configuration.decrypt(provider_obj.credentials) - credentials = tool_configuration.mask_tool_credentials(credentials) - return credentials + default_provider = providers[0] + default_provider.is_default = True + provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id) + + credentials: list[ToolProviderCredentialApiEntity] = [] + encrypters = {} + for provider in providers: + credential_type = provider.credential_type + if credential_type not in encrypters: + encrypters[credential_type] = BuiltinToolManageService.create_tool_encrypter( + tenant_id, provider, provider.provider, provider_controller + )[0] + encrypter = encrypters[credential_type] + decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials)) + credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity( + provider=provider, + credentials=decrypt_credential, + ) + credentials.append(credential_entity) + return credentials @staticmethod - def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str): + def get_builtin_tool_provider_credential_info(tenant_id: str, provider: str) -> ToolProviderCredentialInfoApiEntity: + """ + get builtin tool provider credential info + """ + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + supported_credential_types = provider_controller.get_supported_credential_types() + credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider) + credential_info = ToolProviderCredentialInfoApiEntity( + supported_credential_types=supported_credential_types, + is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider), + credentials=credentials, + ) + + return credential_info + + @staticmethod + def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str): """ delete tool provider """ - provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) + with Session(db.engine) as session: + db_provider = ( + session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + .first() + ) - if provider_obj is None: - raise ValueError(f"you have not added provider {provider_name}") + if db_provider is None: + raise ValueError(f"you have not added provider {provider}") - db.session.delete(provider_obj) - db.session.commit() + session.delete(db_provider) + session.commit() - # delete cache - provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) - tool_configuration.delete_tool_credentials_cache() + # delete cache + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + _, cache = BuiltinToolManageService.create_tool_encrypter( + tenant_id, db_provider, provider, provider_controller + ) + cache.delete() return {"result": "success"} + @staticmethod + def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str): + """ + set default provider + """ + with Session(db.engine) as session: + # get provider + target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first() + if target_provider is None: + raise ValueError("provider not found") + + # clear default provider + session.query(BuiltinToolProvider).filter_by( + tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True + ).update({"is_default": False}) + + # set new default provider + target_provider.is_default = True + session.commit() + return {"result": "success"} + + @staticmethod + def is_oauth_system_client_exists(provider_name: str) -> bool: + """ + check if oauth system client exists + """ + tool_provider = ToolProviderID(provider_name) + with Session(db.engine).no_autoflush as session: + system_client: ToolOAuthSystemClient | None = ( + session.query(ToolOAuthSystemClient) + .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) + .first() + ) + return system_client is not None + + @staticmethod + def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool: + """ + check if oauth custom client is enabled + """ + tool_provider = ToolProviderID(provider) + with Session(db.engine).no_autoflush as session: + user_client: ToolOAuthTenantClient | None = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + provider=tool_provider.provider_name, + plugin_id=tool_provider.plugin_id, + enabled=True, + ) + .first() + ) + return user_client is not None and user_client.enabled + + @staticmethod + def get_oauth_client(tenant_id: str, provider: str) -> Mapping[str, Any] | None: + """ + get builtin tool provider + """ + tool_provider = ToolProviderID(provider) + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), + ) + with Session(db.engine).no_autoflush as session: + user_client: ToolOAuthTenantClient | None = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + provider=tool_provider.provider_name, + plugin_id=tool_provider.plugin_id, + enabled=True, + ) + .first() + ) + oauth_params: Mapping[str, Any] | None = None + if user_client: + oauth_params = encrypter.decrypt(user_client.oauth_params) + return oauth_params + + # only verified provider can use custom oauth client + is_verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified( + tenant_id, provider.plugin_unique_identifier + ) + if not is_verified: + return oauth_params + + system_client: ToolOAuthSystemClient | None = ( + session.query(ToolOAuthSystemClient) + .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) + .first() + ) + if system_client: + try: + oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params) + except Exception as e: + raise ValueError(f"Error decrypting system oauth params: {e}") + + return oauth_params + @staticmethod def get_builtin_tool_provider_icon(provider: str): """ @@ -234,9 +543,7 @@ class BuiltinToolManageService: with db.session.no_autoflush: # get all user added providers - db_providers: list[BuiltinToolProvider] = ( - db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or [] - ) + db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id) # rewrite db_providers for db_provider in db_providers: @@ -275,7 +582,6 @@ class BuiltinToolManageService: ToolTransformService.convert_tool_entity_to_api_entity( tenant_id=tenant_id, tool=tool, - credentials=user_builtin_provider.original_credentials, labels=ToolLabelManager.get_tool_labels(provider_controller), ) ) @@ -287,43 +593,153 @@ class BuiltinToolManageService: return BuiltinToolProviderSort.sort(result) @staticmethod - def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: - try: - full_provider_name = provider_name - provider_id_entity = ToolProviderID(provider_name) - provider_name = provider_id_entity.provider_name - if provider_id_entity.organization != "langgenius": - provider_obj = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == full_provider_name, + def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]: + """ + This method is used to fetch the builtin provider from the database + 1.if the default provider exists, return the default provider + 2.if the default provider does not exist, return the oldest provider + """ + with Session(db.engine) as session: + try: + full_provider_name = provider_name + provider_id_entity = ToolProviderID(provider_name) + provider_name = provider_id_entity.provider_name + + if provider_id_entity.organization != "langgenius": + provider = ( + session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == full_provider_name, + ) + .order_by( + BuiltinToolProvider.is_default.desc(), # default=True first + BuiltinToolProvider.created_at.asc(), # oldest first + ) + .first() ) - .first() - ) - else: - provider_obj = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == provider_name) - | (BuiltinToolProvider.provider == full_provider_name), + else: + provider = ( + session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == provider_name) + | (BuiltinToolProvider.provider == full_provider_name), + ) + .order_by( + BuiltinToolProvider.is_default.desc(), # default=True first + BuiltinToolProvider.created_at.asc(), # oldest first + ) + .first() + ) + + if provider is None: + return None + + provider.provider = ToolProviderID(provider.provider).to_string() + return provider + except Exception: + # it's an old provider without organization + return ( + session.query(BuiltinToolProvider) + .filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name) + .order_by( + BuiltinToolProvider.is_default.desc(), # default=True first + BuiltinToolProvider.created_at.asc(), # oldest first ) .first() ) - if provider_obj is None: - return None + @staticmethod + def save_custom_oauth_client_params( + tenant_id: str, + provider: str, + client_params: Optional[dict] = None, + enable_oauth_custom_client: Optional[bool] = None, + ): + """ + setup oauth custom client + """ + if client_params is None and enable_oauth_custom_client is None: + return {"result": "success"} - provider_obj.provider = ToolProviderID(provider_obj.provider).to_string() - return provider_obj - except Exception: - # it's an old provider without organization - return ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == provider_name), + tool_provider = ToolProviderID(provider) + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller: + raise ToolProviderNotFoundError(f"Provider {provider} not found") + + if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)): + raise ValueError(f"Provider {provider} is not a builtin or plugin provider") + + with Session(db.engine) as session: + custom_client_params = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + plugin_id=tool_provider.plugin_id, + provider=tool_provider.provider_name, ) .first() ) + + # if the record does not exist, create a basic record + if custom_client_params is None: + custom_client_params = ToolOAuthTenantClient( + tenant_id=tenant_id, + plugin_id=tool_provider.plugin_id, + provider=tool_provider.provider_name, + ) + session.add(custom_client_params) + + if client_params is not None: + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), + ) + original_params = encrypter.decrypt(custom_client_params.oauth_params) + new_params: dict = { + key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE) + for key, value in client_params.items() + } + custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params)) + + if enable_oauth_custom_client is not None: + custom_client_params.enabled = enable_oauth_custom_client + + session.commit() + return {"result": "success"} + + @staticmethod + def get_custom_oauth_client_params(tenant_id: str, provider: str): + """ + get custom oauth client params + """ + with Session(db.engine) as session: + tool_provider = ToolProviderID(provider) + custom_oauth_client_params: ToolOAuthTenantClient | None = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + plugin_id=tool_provider.plugin_id, + provider=tool_provider.provider_name, + ) + .first() + ) + if custom_oauth_client_params is None: + return {} + + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller: + raise ToolProviderNotFoundError(f"Provider {provider} not found") + + if not isinstance(provider_controller, BuiltinToolProviderController): + raise ValueError(f"Provider {provider} is not a builtin or plugin provider") + + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), + ) + + return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params)) diff --git a/api/services/tools/mcp_tools_mange_service.py b/api/services/tools/mcp_tools_mange_service.py index 7c23abda4..fda6da598 100644 --- a/api/services/tools/mcp_tools_mange_service.py +++ b/api/services/tools/mcp_tools_mange_service.py @@ -7,13 +7,14 @@ from sqlalchemy import or_ from sqlalchemy.exc import IntegrityError from core.helper import encrypter +from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.error import MCPAuthError, MCPError from core.mcp.mcp_client import MCPClient from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType from core.tools.mcp_tool.provider import MCPToolProviderController -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import ProviderConfigEncrypter from extensions.ext_database import db from models.tools import MCPToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -69,6 +70,7 @@ class MCPToolManageService: MCPToolProvider.server_url_hash == server_url_hash, MCPToolProvider.server_identifier == server_identifier, ), + MCPToolProvider.tenant_id == tenant_id, ) .first() ) @@ -197,8 +199,7 @@ class MCPToolManageService: tool_configuration = ProviderConfigEncrypter( tenant_id=mcp_provider.tenant_id, config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.provider_id, + provider_config_cache=NoOpProviderCredentialCache(), ) credentials = tool_configuration.encrypt(credentials) mcp_provider.updated_at = datetime.now() diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 3d0c35cd9..36b892e20 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -5,21 +5,23 @@ from typing import Any, Optional, Union, cast from yarl import URL from configs import dify_config +from core.helper.provider_cache import ToolProviderCredentialsCache from core.mcp.types import Tool as MCPTool from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.custom_tool.provider import ApiToolProviderController -from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity +from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, + CredentialType, ToolParameter, ToolProviderType, ) from core.tools.plugin_tool.provider import PluginToolProviderController -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider @@ -119,7 +121,12 @@ class ToolTransformService: result.plugin_unique_identifier = provider_controller.plugin_unique_identifier # get credentials schema - schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()} + schema = { + x.to_basic_provider_config().name: x + for x in provider_controller.get_credentials_schema_by_type( + CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY + ) + } for name, value in schema.items(): if result.masked_credentials: @@ -136,15 +143,23 @@ class ToolTransformService: credentials = db_provider.credentials # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_provider_encrypter( tenant_id=db_provider.tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type( + CredentialType.of(db_provider.credential_type) + ) + ], + cache=ToolProviderCredentialsCache( + tenant_id=db_provider.tenant_id, + provider=db_provider.provider, + credential_id=db_provider.id, + ), ) # decrypt the credentials and mask the credentials - decrypted_credentials = tool_configuration.decrypt(data=credentials) - masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials) + decrypted_credentials = encrypter.decrypt(data=credentials) + masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials) result.masked_credentials = masked_credentials result.original_credentials = decrypted_credentials @@ -287,16 +302,14 @@ class ToolTransformService: if decrypt_credentials: # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=db_provider.tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) # decrypt the credentials and mask the credentials - decrypted_credentials = tool_configuration.decrypt(data=credentials) - masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials) + decrypted_credentials = encrypter.decrypt(data=credentials) + masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials) result.masked_credentials = masked_credentials @@ -306,7 +319,6 @@ class ToolTransformService: def convert_tool_entity_to_api_entity( tool: Union[ApiToolBundle, WorkflowTool, Tool], tenant_id: str, - credentials: dict | None = None, labels: list[str] | None = None, ) -> ToolApiEntity: """ @@ -316,7 +328,7 @@ class ToolTransformService: # fork tool runtime tool = tool.fork_tool_runtime( runtime=ToolRuntime( - credentials=credentials or {}, + credentials={}, tenant_id=tenant_id, ) ) @@ -357,6 +369,19 @@ class ToolTransformService: labels=labels or [], ) + @staticmethod + def convert_builtin_provider_to_credential_entity( + provider: BuiltinToolProvider, credentials: dict + ) -> ToolProviderCredentialApiEntity: + return ToolProviderCredentialApiEntity( + id=provider.id, + name=provider.name, + provider=provider.provider, + credential_type=CredentialType.of(provider.credential_type), + is_default=provider.is_default, + credentials=credentials, + ) + @staticmethod def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]: """ diff --git a/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py b/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py new file mode 100644 index 000000000..30990f8d5 --- /dev/null +++ b/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py @@ -0,0 +1,619 @@ +import base64 +import hashlib +from unittest.mock import patch + +import pytest +from Crypto.Cipher import AES +from Crypto.Random import get_random_bytes +from Crypto.Util.Padding import pad + +from core.tools.utils.system_oauth_encryption import ( + OAuthEncryptionError, + SystemOAuthEncrypter, + create_system_oauth_encrypter, + decrypt_system_oauth_params, + encrypt_system_oauth_params, + get_system_oauth_encrypter, +) + + +class TestSystemOAuthEncrypter: + """Test cases for SystemOAuthEncrypter class""" + + def test_init_with_secret_key(self): + """Test initialization with provided secret key""" + secret_key = "test_secret_key" + encrypter = SystemOAuthEncrypter(secret_key=secret_key) + expected_key = hashlib.sha256(secret_key.encode()).digest() + assert encrypter.key == expected_key + + def test_init_with_none_secret_key(self): + """Test initialization with None secret key falls back to config""" + with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: + mock_config.SECRET_KEY = "config_secret" + encrypter = SystemOAuthEncrypter(secret_key=None) + expected_key = hashlib.sha256(b"config_secret").digest() + assert encrypter.key == expected_key + + def test_init_with_empty_secret_key(self): + """Test initialization with empty secret key""" + encrypter = SystemOAuthEncrypter(secret_key="") + expected_key = hashlib.sha256(b"").digest() + assert encrypter.key == expected_key + + def test_init_without_secret_key_uses_config(self): + """Test initialization without secret key uses config""" + with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: + mock_config.SECRET_KEY = "default_secret" + encrypter = SystemOAuthEncrypter() + expected_key = hashlib.sha256(b"default_secret").digest() + assert encrypter.key == expected_key + + def test_encrypt_oauth_params_basic(self): + """Test basic OAuth parameters encryption""" + encrypter = SystemOAuthEncrypter("test_secret") + oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted = encrypter.encrypt_oauth_params(oauth_params) + + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + # Should be valid base64 + try: + base64.b64decode(encrypted) + except Exception: + pytest.fail("Encrypted result is not valid base64") + + def test_encrypt_oauth_params_empty_dict(self): + """Test encryption with empty dictionary""" + encrypter = SystemOAuthEncrypter("test_secret") + oauth_params = {} + + encrypted = encrypter.encrypt_oauth_params(oauth_params) + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + + def test_encrypt_oauth_params_complex_data(self): + """Test encryption with complex data structures""" + encrypter = SystemOAuthEncrypter("test_secret") + oauth_params = { + "client_id": "test_id", + "client_secret": "test_secret", + "scopes": ["read", "write", "admin"], + "metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True}, + "numeric_value": 42, + "boolean_value": False, + "null_value": None, + } + + encrypted = encrypter.encrypt_oauth_params(oauth_params) + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + + def test_encrypt_oauth_params_unicode_data(self): + """Test encryption with unicode data""" + encrypter = SystemOAuthEncrypter("test_secret") + oauth_params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case 🚀"} + + encrypted = encrypter.encrypt_oauth_params(oauth_params) + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + + def test_encrypt_oauth_params_large_data(self): + """Test encryption with large data""" + encrypter = SystemOAuthEncrypter("test_secret") + oauth_params = { + "client_id": "test_id", + "large_data": "x" * 10000, # 10KB of data + } + + encrypted = encrypter.encrypt_oauth_params(oauth_params) + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + + def test_encrypt_oauth_params_invalid_input(self): + """Test encryption with invalid input types""" + encrypter = SystemOAuthEncrypter("test_secret") + + with pytest.raises(Exception): # noqa: B017 + encrypter.encrypt_oauth_params(None) # type: ignore + + with pytest.raises(Exception): # noqa: B017 + encrypter.encrypt_oauth_params("not_a_dict") # type: ignore + + def test_decrypt_oauth_params_basic(self): + """Test basic OAuth parameters decryption""" + encrypter = SystemOAuthEncrypter("test_secret") + original_params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted = encrypter.encrypt_oauth_params(original_params) + decrypted = encrypter.decrypt_oauth_params(encrypted) + + assert decrypted == original_params + + def test_decrypt_oauth_params_empty_dict(self): + """Test decryption of empty dictionary""" + encrypter = SystemOAuthEncrypter("test_secret") + original_params = {} + + encrypted = encrypter.encrypt_oauth_params(original_params) + decrypted = encrypter.decrypt_oauth_params(encrypted) + + assert decrypted == original_params + + def test_decrypt_oauth_params_complex_data(self): + """Test decryption with complex data structures""" + encrypter = SystemOAuthEncrypter("test_secret") + original_params = { + "client_id": "test_id", + "client_secret": "test_secret", + "scopes": ["read", "write", "admin"], + "metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True}, + "numeric_value": 42, + "boolean_value": False, + "null_value": None, + } + + encrypted = encrypter.encrypt_oauth_params(original_params) + decrypted = encrypter.decrypt_oauth_params(encrypted) + + assert decrypted == original_params + + def test_decrypt_oauth_params_unicode_data(self): + """Test decryption with unicode data""" + encrypter = SystemOAuthEncrypter("test_secret") + original_params = { + "client_id": "test_id", + "client_secret": "test_secret", + "description": "This is a test case 🚀", + } + + encrypted = encrypter.encrypt_oauth_params(original_params) + decrypted = encrypter.decrypt_oauth_params(encrypted) + + assert decrypted == original_params + + def test_decrypt_oauth_params_large_data(self): + """Test decryption with large data""" + encrypter = SystemOAuthEncrypter("test_secret") + original_params = { + "client_id": "test_id", + "large_data": "x" * 10000, # 10KB of data + } + + encrypted = encrypter.encrypt_oauth_params(original_params) + decrypted = encrypter.decrypt_oauth_params(encrypted) + + assert decrypted == original_params + + def test_decrypt_oauth_params_invalid_base64(self): + """Test decryption with invalid base64 data""" + encrypter = SystemOAuthEncrypter("test_secret") + + with pytest.raises(OAuthEncryptionError): + encrypter.decrypt_oauth_params("invalid_base64!") + + def test_decrypt_oauth_params_empty_string(self): + """Test decryption with empty string""" + encrypter = SystemOAuthEncrypter("test_secret") + + with pytest.raises(ValueError) as exc_info: + encrypter.decrypt_oauth_params("") + + assert "encrypted_data cannot be empty" in str(exc_info.value) + + def test_decrypt_oauth_params_non_string_input(self): + """Test decryption with non-string input""" + encrypter = SystemOAuthEncrypter("test_secret") + + with pytest.raises(ValueError) as exc_info: + encrypter.decrypt_oauth_params(123) # type: ignore + + assert "encrypted_data must be a string" in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + encrypter.decrypt_oauth_params(None) # type: ignore + + assert "encrypted_data must be a string" in str(exc_info.value) + + def test_decrypt_oauth_params_too_short_data(self): + """Test decryption with too short encrypted data""" + encrypter = SystemOAuthEncrypter("test_secret") + + # Create data that's too short (less than 32 bytes) + short_data = base64.b64encode(b"short").decode() + + with pytest.raises(OAuthEncryptionError) as exc_info: + encrypter.decrypt_oauth_params(short_data) + + assert "Invalid encrypted data format" in str(exc_info.value) + + def test_decrypt_oauth_params_corrupted_data(self): + """Test decryption with corrupted data""" + encrypter = SystemOAuthEncrypter("test_secret") + + # Create corrupted data (valid base64 but invalid encrypted content) + corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage + + with pytest.raises(OAuthEncryptionError): + encrypter.decrypt_oauth_params(corrupted_data) + + def test_decrypt_oauth_params_wrong_key(self): + """Test decryption with wrong key""" + encrypter1 = SystemOAuthEncrypter("secret1") + encrypter2 = SystemOAuthEncrypter("secret2") + + original_params = {"client_id": "test_id", "client_secret": "test_secret"} + encrypted = encrypter1.encrypt_oauth_params(original_params) + + with pytest.raises(OAuthEncryptionError): + encrypter2.decrypt_oauth_params(encrypted) + + def test_encryption_decryption_consistency(self): + """Test that encryption and decryption are consistent""" + encrypter = SystemOAuthEncrypter("test_secret") + + test_cases = [ + {}, + {"simple": "value"}, + {"client_id": "id", "client_secret": "secret"}, + {"complex": {"nested": {"deep": "value"}}}, + {"unicode": "test 🚀"}, + {"numbers": 42, "boolean": True, "null": None}, + {"array": [1, 2, 3, "four", {"five": 5}]}, + ] + + for original_params in test_cases: + encrypted = encrypter.encrypt_oauth_params(original_params) + decrypted = encrypter.decrypt_oauth_params(encrypted) + assert decrypted == original_params, f"Failed for case: {original_params}" + + def test_encryption_randomness(self): + """Test that encryption produces different results for same input""" + encrypter = SystemOAuthEncrypter("test_secret") + oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted1 = encrypter.encrypt_oauth_params(oauth_params) + encrypted2 = encrypter.encrypt_oauth_params(oauth_params) + + # Should be different due to random IV + assert encrypted1 != encrypted2 + + # But should decrypt to same result + decrypted1 = encrypter.decrypt_oauth_params(encrypted1) + decrypted2 = encrypter.decrypt_oauth_params(encrypted2) + assert decrypted1 == decrypted2 == oauth_params + + def test_different_secret_keys_produce_different_results(self): + """Test that different secret keys produce different encrypted results""" + encrypter1 = SystemOAuthEncrypter("secret1") + encrypter2 = SystemOAuthEncrypter("secret2") + + oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted1 = encrypter1.encrypt_oauth_params(oauth_params) + encrypted2 = encrypter2.encrypt_oauth_params(oauth_params) + + # Should produce different encrypted results + assert encrypted1 != encrypted2 + + # But each should decrypt correctly with its own key + decrypted1 = encrypter1.decrypt_oauth_params(encrypted1) + decrypted2 = encrypter2.decrypt_oauth_params(encrypted2) + assert decrypted1 == decrypted2 == oauth_params + + @patch("core.tools.utils.system_oauth_encryption.get_random_bytes") + def test_encrypt_oauth_params_crypto_error(self, mock_get_random_bytes): + """Test encryption when crypto operation fails""" + mock_get_random_bytes.side_effect = Exception("Crypto error") + + encrypter = SystemOAuthEncrypter("test_secret") + oauth_params = {"client_id": "test_id"} + + with pytest.raises(OAuthEncryptionError) as exc_info: + encrypter.encrypt_oauth_params(oauth_params) + + assert "Encryption failed" in str(exc_info.value) + + @patch("core.tools.utils.system_oauth_encryption.TypeAdapter") + def test_encrypt_oauth_params_serialization_error(self, mock_type_adapter): + """Test encryption when JSON serialization fails""" + mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error") + + encrypter = SystemOAuthEncrypter("test_secret") + oauth_params = {"client_id": "test_id"} + + with pytest.raises(OAuthEncryptionError) as exc_info: + encrypter.encrypt_oauth_params(oauth_params) + + assert "Encryption failed" in str(exc_info.value) + + def test_decrypt_oauth_params_invalid_json(self): + """Test decryption with invalid JSON data""" + encrypter = SystemOAuthEncrypter("test_secret") + + # Create valid encrypted data but with invalid JSON content + iv = get_random_bytes(16) + cipher = AES.new(encrypter.key, AES.MODE_CBC, iv) + invalid_json = b"invalid json content" + padded_data = pad(invalid_json, AES.block_size) + encrypted_data = cipher.encrypt(padded_data) + combined = iv + encrypted_data + encoded = base64.b64encode(combined).decode() + + with pytest.raises(OAuthEncryptionError): + encrypter.decrypt_oauth_params(encoded) + + def test_key_derivation_consistency(self): + """Test that key derivation is consistent""" + secret_key = "test_secret" + encrypter1 = SystemOAuthEncrypter(secret_key) + encrypter2 = SystemOAuthEncrypter(secret_key) + + assert encrypter1.key == encrypter2.key + + # Keys should be 32 bytes (256 bits) + assert len(encrypter1.key) == 32 + + +class TestFactoryFunctions: + """Test cases for factory functions""" + + def test_create_system_oauth_encrypter_with_secret(self): + """Test factory function with secret key""" + secret_key = "test_secret" + encrypter = create_system_oauth_encrypter(secret_key) + + assert isinstance(encrypter, SystemOAuthEncrypter) + expected_key = hashlib.sha256(secret_key.encode()).digest() + assert encrypter.key == expected_key + + def test_create_system_oauth_encrypter_without_secret(self): + """Test factory function without secret key""" + with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: + mock_config.SECRET_KEY = "config_secret" + encrypter = create_system_oauth_encrypter() + + assert isinstance(encrypter, SystemOAuthEncrypter) + expected_key = hashlib.sha256(b"config_secret").digest() + assert encrypter.key == expected_key + + def test_create_system_oauth_encrypter_with_none_secret(self): + """Test factory function with None secret key""" + with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: + mock_config.SECRET_KEY = "config_secret" + encrypter = create_system_oauth_encrypter(None) + + assert isinstance(encrypter, SystemOAuthEncrypter) + expected_key = hashlib.sha256(b"config_secret").digest() + assert encrypter.key == expected_key + + +class TestGlobalEncrypterInstance: + """Test cases for global encrypter instance""" + + def test_get_system_oauth_encrypter_singleton(self): + """Test that get_system_oauth_encrypter returns singleton instance""" + # Clear the global instance first + import core.tools.utils.system_oauth_encryption + + core.tools.utils.system_oauth_encryption._oauth_encrypter = None + + encrypter1 = get_system_oauth_encrypter() + encrypter2 = get_system_oauth_encrypter() + + assert encrypter1 is encrypter2 + assert isinstance(encrypter1, SystemOAuthEncrypter) + + def test_get_system_oauth_encrypter_uses_config(self): + """Test that global encrypter uses config""" + # Clear the global instance first + import core.tools.utils.system_oauth_encryption + + core.tools.utils.system_oauth_encryption._oauth_encrypter = None + + with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: + mock_config.SECRET_KEY = "global_secret" + encrypter = get_system_oauth_encrypter() + + expected_key = hashlib.sha256(b"global_secret").digest() + assert encrypter.key == expected_key + + +class TestConvenienceFunctions: + """Test cases for convenience functions""" + + def test_encrypt_system_oauth_params(self): + """Test encrypt_system_oauth_params convenience function""" + oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted = encrypt_system_oauth_params(oauth_params) + + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + + def test_decrypt_system_oauth_params(self): + """Test decrypt_system_oauth_params convenience function""" + oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} + + encrypted = encrypt_system_oauth_params(oauth_params) + decrypted = decrypt_system_oauth_params(encrypted) + + assert decrypted == oauth_params + + def test_convenience_functions_consistency(self): + """Test that convenience functions work consistently""" + test_cases = [ + {}, + {"simple": "value"}, + {"client_id": "id", "client_secret": "secret"}, + {"complex": {"nested": {"deep": "value"}}}, + {"unicode": "test 🚀"}, + {"numbers": 42, "boolean": True, "null": None}, + ] + + for original_params in test_cases: + encrypted = encrypt_system_oauth_params(original_params) + decrypted = decrypt_system_oauth_params(encrypted) + assert decrypted == original_params, f"Failed for case: {original_params}" + + def test_convenience_functions_with_errors(self): + """Test convenience functions with error conditions""" + # Test encryption with invalid input + with pytest.raises(Exception): # noqa: B017 + encrypt_system_oauth_params(None) # type: ignore + + # Test decryption with invalid input + with pytest.raises(ValueError): + decrypt_system_oauth_params("") + + with pytest.raises(ValueError): + decrypt_system_oauth_params(None) # type: ignore + + +class TestErrorHandling: + """Test cases for error handling""" + + def test_oauth_encryption_error_inheritance(self): + """Test that OAuthEncryptionError is a proper exception""" + error = OAuthEncryptionError("Test error") + assert isinstance(error, Exception) + assert str(error) == "Test error" + + def test_oauth_encryption_error_with_cause(self): + """Test OAuthEncryptionError with cause""" + original_error = ValueError("Original error") + error = OAuthEncryptionError("Wrapper error") + error.__cause__ = original_error + + assert isinstance(error, Exception) + assert str(error) == "Wrapper error" + assert error.__cause__ is original_error + + def test_error_messages_are_informative(self): + """Test that error messages are informative""" + encrypter = SystemOAuthEncrypter("test_secret") + + # Test empty string error + with pytest.raises(ValueError) as exc_info: + encrypter.decrypt_oauth_params("") + assert "encrypted_data cannot be empty" in str(exc_info.value) + + # Test non-string error + with pytest.raises(ValueError) as exc_info: + encrypter.decrypt_oauth_params(123) # type: ignore + assert "encrypted_data must be a string" in str(exc_info.value) + + # Test invalid format error + short_data = base64.b64encode(b"short").decode() + with pytest.raises(OAuthEncryptionError) as exc_info: + encrypter.decrypt_oauth_params(short_data) + assert "Invalid encrypted data format" in str(exc_info.value) + + +class TestEdgeCases: + """Test cases for edge cases and boundary conditions""" + + def test_very_long_secret_key(self): + """Test with very long secret key""" + long_secret = "x" * 10000 + encrypter = SystemOAuthEncrypter(long_secret) + + # Key should still be 32 bytes due to SHA-256 + assert len(encrypter.key) == 32 + + # Should still work normally + oauth_params = {"client_id": "test_id"} + encrypted = encrypter.encrypt_oauth_params(oauth_params) + decrypted = encrypter.decrypt_oauth_params(encrypted) + assert decrypted == oauth_params + + def test_special_characters_in_secret_key(self): + """Test with special characters in secret key""" + special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~test🚀" + encrypter = SystemOAuthEncrypter(special_secret) + + oauth_params = {"client_id": "test_id"} + encrypted = encrypter.encrypt_oauth_params(oauth_params) + decrypted = encrypter.decrypt_oauth_params(encrypted) + assert decrypted == oauth_params + + def test_empty_values_in_oauth_params(self): + """Test with empty values in oauth params""" + oauth_params = { + "client_id": "", + "client_secret": "", + "empty_dict": {}, + "empty_list": [], + "empty_string": "", + "zero": 0, + "false": False, + "none": None, + } + + encrypter = SystemOAuthEncrypter("test_secret") + encrypted = encrypter.encrypt_oauth_params(oauth_params) + decrypted = encrypter.decrypt_oauth_params(encrypted) + assert decrypted == oauth_params + + def test_deeply_nested_oauth_params(self): + """Test with deeply nested oauth params""" + oauth_params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}} + + encrypter = SystemOAuthEncrypter("test_secret") + encrypted = encrypter.encrypt_oauth_params(oauth_params) + decrypted = encrypter.decrypt_oauth_params(encrypted) + assert decrypted == oauth_params + + def test_oauth_params_with_all_json_types(self): + """Test with all JSON-supported data types""" + oauth_params = { + "string": "test_string", + "integer": 42, + "float": 3.14159, + "boolean_true": True, + "boolean_false": False, + "null_value": None, + "empty_string": "", + "array": [1, "two", 3.0, True, False, None], + "object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True}, + } + + encrypter = SystemOAuthEncrypter("test_secret") + encrypted = encrypter.encrypt_oauth_params(oauth_params) + decrypted = encrypter.decrypt_oauth_params(encrypted) + assert decrypted == oauth_params + + +class TestPerformance: + """Test cases for performance considerations""" + + def test_large_oauth_params(self): + """Test with large oauth params""" + large_value = "x" * 100000 # 100KB + oauth_params = {"client_id": "test_id", "large_data": large_value} + + encrypter = SystemOAuthEncrypter("test_secret") + encrypted = encrypter.encrypt_oauth_params(oauth_params) + decrypted = encrypter.decrypt_oauth_params(encrypted) + assert decrypted == oauth_params + + def test_many_fields_oauth_params(self): + """Test with many fields in oauth params""" + oauth_params = {f"field_{i}": f"value_{i}" for i in range(1000)} + + encrypter = SystemOAuthEncrypter("test_secret") + encrypted = encrypter.encrypt_oauth_params(oauth_params) + decrypted = encrypter.decrypt_oauth_params(encrypted) + assert decrypted == oauth_params + + def test_repeated_encryption_decryption(self): + """Test repeated encryption and decryption operations""" + encrypter = SystemOAuthEncrypter("test_secret") + oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} + + # Test multiple rounds of encryption/decryption + for i in range(100): + encrypted = encrypter.encrypt_oauth_params(oauth_params) + decrypted = encrypter.decrypt_oauth_params(encrypted) + assert decrypted == oauth_params diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index a1b82ab2f..b4711ea39 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -18,7 +18,6 @@ import AppIcon from '@/app/components/base/app-icon' import Button from '@/app/components/base/button' import Indicator from '@/app/components/header/indicator' import Switch from '@/app/components/base/switch' -import Toast from '@/app/components/base/toast' import ConfigContext from '@/context/debug-configuration' import type { AgentTool } from '@/types/app' import { type Collection, CollectionType } from '@/app/components/tools/types' @@ -26,8 +25,6 @@ import { MAX_TOOLS_NUM } from '@/config' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' import Tooltip from '@/app/components/base/tooltip' import { DefaultToolIcon } from '@/app/components/base/icons/src/public/other' -import ConfigCredential from '@/app/components/tools/setting/build-in/config-credentials' -import { updateBuiltInToolCredential } from '@/service/tools' import cn from '@/utils/classnames' import ToolPicker from '@/app/components/workflow/block-selector/tool-picker' import type { ToolDefaultValue, ToolValue } from '@/app/components/workflow/block-selector/types' @@ -57,13 +54,7 @@ const AgentTools: FC = () => { const formattingChangedDispatcher = useFormattingChangedDispatcher() const [currentTool, setCurrentTool] = useState(null) - const currentCollection = useMemo(() => { - if (!currentTool) return null - const collection = collectionList.find(collection => canFindTool(collection.id, currentTool?.provider_id) && collection.type === currentTool?.provider_type) - return collection - }, [currentTool, collectionList]) const [isShowSettingTool, setIsShowSettingTool] = useState(false) - const [isShowSettingAuth, setShowSettingAuth] = useState(false) const tools = (modelConfig?.agentConfig?.tools as AgentTool[] || []).map((item) => { const collection = collectionList.find( collection => @@ -100,17 +91,6 @@ const AgentTools: FC = () => { formattingChangedDispatcher() } - const handleToolAuthSetting = (value: AgentToolWithMoreInfo) => { - const newModelConfig = produce(modelConfig, (draft) => { - const tool = (draft.agentConfig.tools).find((item: any) => item.provider_id === value?.collection?.id && item.tool_name === value?.tool_name) - if (tool) - (tool as AgentTool).notAuthor = false - }) - setModelConfig(newModelConfig) - setIsShowSettingTool(false) - formattingChangedDispatcher() - } - const [isDeleting, setIsDeleting] = useState(-1) const getToolValue = (tool: ToolDefaultValue) => { return { @@ -144,6 +124,20 @@ const AgentTools: FC = () => { return item.provider_name } + const handleAuthorizationItemClick = useCallback((credentialId: string) => { + const newModelConfig = produce(modelConfig, (draft) => { + const tool = (draft.agentConfig.tools).find((item: any) => item.provider_id === currentTool?.provider_id) + if (tool) + (tool as AgentTool).credential_id = credentialId + }) + setCurrentTool({ + ...currentTool, + credential_id: credentialId, + } as any) + setModelConfig(newModelConfig) + formattingChangedDispatcher() + }, [currentTool, modelConfig, setModelConfig, formattingChangedDispatcher]) + return ( <> { {item.notAuthor && ( +
+ + ) + } + + + + + {bottomSlot} + + + + ) +} + +export default memo(Modal) diff --git a/web/app/components/base/select/pure.tsx b/web/app/components/base/select/pure.tsx index 81cc2fbad..be88c936f 100644 --- a/web/app/components/base/select/pure.tsx +++ b/web/app/components/base/select/pure.tsx @@ -39,6 +39,9 @@ type PureSelectProps = { itemClassName?: string title?: string }, + placeholder?: string + disabled?: boolean + triggerPopupSameWidth?: boolean } const PureSelect = ({ options, @@ -47,6 +50,9 @@ const PureSelect = ({ containerProps, triggerProps, popupProps, + placeholder, + disabled, + triggerPopupSameWidth, }: PureSelectProps) => { const { t } = useTranslation() const { @@ -74,7 +80,7 @@ const PureSelect = ({ }, [onOpenChange]) const selectedOption = options.find(option => option.value === value) - const triggerText = selectedOption?.label || t('common.placeholder.select') + const triggerText = selectedOption?.label || placeholder || t('common.placeholder.select') return ( handleOpenChange(!mergedOpen)} @@ -135,6 +142,7 @@ const PureSelect = ({ )} title={option.label} onClick={() => { + if (disabled) return onChange?.(option.value) handleOpenChange(false) }} diff --git a/web/app/components/plugins/plugin-auth/authorize/add-api-key-button.tsx b/web/app/components/plugins/plugin-auth/authorize/add-api-key-button.tsx new file mode 100644 index 000000000..295fc4fa9 --- /dev/null +++ b/web/app/components/plugins/plugin-auth/authorize/add-api-key-button.tsx @@ -0,0 +1,50 @@ +import { + memo, + useState, +} from 'react' +import Button from '@/app/components/base/button' +import type { ButtonProps } from '@/app/components/base/button' +import ApiKeyModal from './api-key-modal' +import type { PluginPayload } from '../types' + +export type AddApiKeyButtonProps = { + pluginPayload: PluginPayload + buttonVariant?: ButtonProps['variant'] + buttonText?: string + disabled?: boolean + onUpdate?: () => void +} +const AddApiKeyButton = ({ + pluginPayload, + buttonVariant = 'secondary-accent', + buttonText = 'use api key', + disabled, + onUpdate, +}: AddApiKeyButtonProps) => { + const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false) + + return ( + <> + + { + isApiKeyModalOpen && ( + setIsApiKeyModalOpen(false)} + onUpdate={onUpdate} + /> + ) + } + + + ) +} + +export default memo(AddApiKeyButton) diff --git a/web/app/components/plugins/plugin-auth/authorize/add-oauth-button.tsx b/web/app/components/plugins/plugin-auth/authorize/add-oauth-button.tsx new file mode 100644 index 000000000..599d70134 --- /dev/null +++ b/web/app/components/plugins/plugin-auth/authorize/add-oauth-button.tsx @@ -0,0 +1,259 @@ +import { + memo, + useCallback, + useMemo, + useState, +} from 'react' +import { useTranslation } from 'react-i18next' +import { + RiClipboardLine, + RiEqualizer2Line, + RiInformation2Fill, +} from '@remixicon/react' +import Button from '@/app/components/base/button' +import type { ButtonProps } from '@/app/components/base/button' +import OAuthClientSettings from './oauth-client-settings' +import cn from '@/utils/classnames' +import type { PluginPayload } from '../types' +import { openOAuthPopup } from '@/hooks/use-oauth' +import Badge from '@/app/components/base/badge' +import { + useGetPluginOAuthClientSchemaHook, + useGetPluginOAuthUrlHook, +} from '../hooks/use-credential' +import type { FormSchema } from '@/app/components/base/form/types' +import { FormTypeEnum } from '@/app/components/base/form/types' +import ActionButton from '@/app/components/base/action-button' +import { useRenderI18nObject } from '@/hooks/use-i18n' + +export type AddOAuthButtonProps = { + pluginPayload: PluginPayload + buttonVariant?: ButtonProps['variant'] + buttonText?: string + className?: string + buttonLeftClassName?: string + buttonRightClassName?: string + dividerClassName?: string + disabled?: boolean + onUpdate?: () => void +} +const AddOAuthButton = ({ + pluginPayload, + buttonVariant = 'primary', + buttonText = 'use oauth', + className, + buttonLeftClassName, + buttonRightClassName, + dividerClassName, + disabled, + onUpdate, +}: AddOAuthButtonProps) => { + const { t } = useTranslation() + const renderI18nObject = useRenderI18nObject() + const [isOAuthSettingsOpen, setIsOAuthSettingsOpen] = useState(false) + const { mutateAsync: getPluginOAuthUrl } = useGetPluginOAuthUrlHook(pluginPayload) + const { data, isLoading } = useGetPluginOAuthClientSchemaHook(pluginPayload) + const { + schema = [], + is_oauth_custom_client_enabled, + is_system_oauth_params_exists, + client_params, + redirect_uri, + } = data || {} + const isConfigured = is_system_oauth_params_exists || is_oauth_custom_client_enabled + const handleOAuth = useCallback(async () => { + const { authorization_url } = await getPluginOAuthUrl() + + if (authorization_url) { + openOAuthPopup( + authorization_url, + () => onUpdate?.(), + ) + } + }, [getPluginOAuthUrl, onUpdate]) + + const renderCustomLabel = useCallback((item: FormSchema) => { + return ( +
+
+
+ +
+
+
+ {t('plugin.auth.clientInfo')} +
+ { + redirect_uri && ( +
+
{redirect_uri}
+ { + navigator.clipboard.writeText(redirect_uri || '') + }} + > + + +
+ ) + } +
+
+
+ {renderI18nObject(item.label as Record)} + { + item.required && ( + * + ) + } +
+
+ ) + }, [t, redirect_uri, renderI18nObject]) + const memorizedSchemas = useMemo(() => { + const result: FormSchema[] = schema.map((item, index) => { + return { + ...item, + label: index === 0 ? renderCustomLabel(item) : item.label, + labelClassName: index === 0 ? 'h-auto' : undefined, + } + }) + if (is_system_oauth_params_exists) { + result.unshift({ + name: '__oauth_client__', + label: t('plugin.auth.oauthClient'), + type: FormTypeEnum.radio, + options: [ + { + label: t('plugin.auth.default'), + value: 'default', + }, + { + label: t('plugin.auth.custom'), + value: 'custom', + }, + ], + required: false, + default: is_oauth_custom_client_enabled ? 'custom' : 'default', + } as FormSchema) + result.forEach((item, index) => { + if (index > 0) { + item.show_on = [ + { + variable: '__oauth_client__', + value: 'custom', + }, + ] + if (client_params) + item.default = client_params[item.name] || item.default + } + }) + } + + return result + }, [schema, renderCustomLabel, t, is_system_oauth_params_exists, is_oauth_custom_client_enabled, client_params]) + + const __auth_client__ = useMemo(() => { + if (isConfigured) { + if (is_oauth_custom_client_enabled) + return 'custom' + return 'default' + } + else { + if (is_system_oauth_params_exists) + return 'default' + return 'custom' + } + }, [isConfigured, is_oauth_custom_client_enabled, is_system_oauth_params_exists]) + + return ( + <> + { + isConfigured && ( + + ) + } + { + !isConfigured && ( + + ) + } + { + isOAuthSettingsOpen && ( + setIsOAuthSettingsOpen(false)} + disabled={disabled || isLoading} + schemas={memorizedSchemas} + onAuth={handleOAuth} + editValues={{ + ...client_params, + __oauth_client__: __auth_client__, + }} + hasOriginalClientParams={Object.keys(client_params || {}).length > 0} + onUpdate={onUpdate} + /> + ) + } + + ) +} + +export default memo(AddOAuthButton) diff --git a/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx b/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx new file mode 100644 index 000000000..d582c660b --- /dev/null +++ b/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx @@ -0,0 +1,181 @@ +import { + memo, + useCallback, + useMemo, + useRef, + useState, +} from 'react' +import { useTranslation } from 'react-i18next' +import { RiExternalLinkLine } from '@remixicon/react' +import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' +import Modal from '@/app/components/base/modal/modal' +import { CredentialTypeEnum } from '../types' +import AuthForm from '@/app/components/base/form/form-scenarios/auth' +import type { FormRefObject } from '@/app/components/base/form/types' +import { FormTypeEnum } from '@/app/components/base/form/types' +import { useToastContext } from '@/app/components/base/toast' +import Loading from '@/app/components/base/loading' +import type { PluginPayload } from '../types' +import { + useAddPluginCredentialHook, + useGetPluginCredentialSchemaHook, + useUpdatePluginCredentialHook, +} from '../hooks/use-credential' +import { useRenderI18nObject } from '@/hooks/use-i18n' + +export type ApiKeyModalProps = { + pluginPayload: PluginPayload + onClose?: () => void + editValues?: Record + onRemove?: () => void + disabled?: boolean + onUpdate?: () => void +} +const ApiKeyModal = ({ + pluginPayload, + onClose, + editValues, + onRemove, + disabled, + onUpdate, +}: ApiKeyModalProps) => { + const { t } = useTranslation() + const { notify } = useToastContext() + const [doingAction, setDoingAction] = useState(false) + const doingActionRef = useRef(doingAction) + const handleSetDoingAction = useCallback((value: boolean) => { + doingActionRef.current = value + setDoingAction(value) + }, []) + const { data = [], isLoading } = useGetPluginCredentialSchemaHook(pluginPayload, CredentialTypeEnum.API_KEY) + const formSchemas = useMemo(() => { + return [ + { + type: FormTypeEnum.textInput, + name: '__name__', + label: t('plugin.auth.authorizationName'), + required: false, + }, + ...data, + ] + }, [data, t]) + const defaultValues = formSchemas.reduce((acc, schema) => { + if (schema.default) + acc[schema.name] = schema.default + return acc + }, {} as Record) + const helpField = formSchemas.find(schema => schema.url && schema.help) + const renderI18nObject = useRenderI18nObject() + const { mutateAsync: addPluginCredential } = useAddPluginCredentialHook(pluginPayload) + const { mutateAsync: updatePluginCredential } = useUpdatePluginCredentialHook(pluginPayload) + const formRef = useRef(null) + const handleConfirm = useCallback(async () => { + if (doingActionRef.current) + return + const { + isCheckValidated, + values, + } = formRef.current?.getFormValues({ + needCheckValidatedValues: true, + needTransformWhenSecretFieldIsPristine: true, + }) || { isCheckValidated: false, values: {} } + if (!isCheckValidated) + return + + try { + const { + __name__, + __credential_id__, + ...restValues + } = values + + handleSetDoingAction(true) + if (editValues) { + await updatePluginCredential({ + credentials: restValues, + credential_id: __credential_id__, + name: __name__ || '', + }) + } + else { + await addPluginCredential({ + credentials: restValues, + type: CredentialTypeEnum.API_KEY, + name: __name__ || '', + }) + } + notify({ + type: 'success', + message: t('common.api.actionSuccess'), + }) + + onClose?.() + onUpdate?.() + } + finally { + handleSetDoingAction(false) + } + }, [addPluginCredential, onClose, onUpdate, updatePluginCredential, notify, t, editValues, handleSetDoingAction]) + + return ( + + + {renderI18nObject(helpField?.help as any)} + + + + ) + } + bottomSlot={ +
+ + {t('common.modelProvider.encrypted.front')} + + PKCS1_OAEP + + {t('common.modelProvider.encrypted.back')} +
+ } + onConfirm={handleConfirm} + showExtraButton={!!editValues} + onExtraButtonClick={onRemove} + disabled={disabled || isLoading || doingAction} + > + { + isLoading && ( +
+ +
+ ) + } + { + !isLoading && !!data.length && ( + + ) + } +
+ ) +} + +export default memo(ApiKeyModal) diff --git a/web/app/components/plugins/plugin-auth/authorize/index.tsx b/web/app/components/plugins/plugin-auth/authorize/index.tsx new file mode 100644 index 000000000..f430d8d48 --- /dev/null +++ b/web/app/components/plugins/plugin-auth/authorize/index.tsx @@ -0,0 +1,104 @@ +import { + memo, + useMemo, +} from 'react' +import { useTranslation } from 'react-i18next' +import AddOAuthButton from './add-oauth-button' +import type { AddOAuthButtonProps } from './add-oauth-button' +import AddApiKeyButton from './add-api-key-button' +import type { AddApiKeyButtonProps } from './add-api-key-button' +import type { PluginPayload } from '../types' + +type AuthorizeProps = { + pluginPayload: PluginPayload + theme?: 'primary' | 'secondary' + showDivider?: boolean + canOAuth?: boolean + canApiKey?: boolean + disabled?: boolean + onUpdate?: () => void +} +const Authorize = ({ + pluginPayload, + theme = 'primary', + showDivider = true, + canOAuth, + canApiKey, + disabled, + onUpdate, +}: AuthorizeProps) => { + const { t } = useTranslation() + const oAuthButtonProps: AddOAuthButtonProps = useMemo(() => { + if (theme === 'secondary') { + return { + buttonText: !canApiKey ? t('plugin.auth.useOAuthAuth') : t('plugin.auth.addOAuth'), + buttonVariant: 'secondary', + className: 'hover:bg-components-button-secondary-bg', + buttonLeftClassName: 'hover:bg-components-button-secondary-bg-hover', + buttonRightClassName: 'hover:bg-components-button-secondary-bg-hover', + dividerClassName: 'bg-divider-regular opacity-100', + pluginPayload, + } + } + + return { + buttonText: !canApiKey ? t('plugin.auth.useOAuthAuth') : t('plugin.auth.addOAuth'), + pluginPayload, + } + }, [canApiKey, theme, pluginPayload, t]) + + const apiKeyButtonProps: AddApiKeyButtonProps = useMemo(() => { + if (theme === 'secondary') { + return { + pluginPayload, + buttonVariant: 'secondary', + buttonText: !canOAuth ? t('plugin.auth.useApiAuth') : t('plugin.auth.addApi'), + } + } + return { + pluginPayload, + buttonText: !canOAuth ? t('plugin.auth.useApiAuth') : t('plugin.auth.addApi'), + buttonVariant: !canOAuth ? 'primary' : 'secondary-accent', + } + }, [canOAuth, theme, pluginPayload, t]) + + return ( + <> +
+ { + canOAuth && ( +
+ +
+ ) + } + { + showDivider && canOAuth && canApiKey && ( +
+
+ or +
+
+ ) + } + { + canApiKey && ( +
+ +
+ ) + } +
+ + ) +} + +export default memo(Authorize) diff --git a/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx b/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx new file mode 100644 index 000000000..14c7ed957 --- /dev/null +++ b/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx @@ -0,0 +1,188 @@ +import { + memo, + useCallback, + useRef, + useState, +} from 'react' +import { RiExternalLinkLine } from '@remixicon/react' +import { + useForm, + useStore, +} from '@tanstack/react-form' +import { useTranslation } from 'react-i18next' +import Modal from '@/app/components/base/modal/modal' +import { + useDeletePluginOAuthCustomClientHook, + useInvalidPluginOAuthClientSchemaHook, + useSetPluginOAuthCustomClientHook, +} from '../hooks/use-credential' +import type { PluginPayload } from '../types' +import AuthForm from '@/app/components/base/form/form-scenarios/auth' +import type { + FormRefObject, + FormSchema, +} from '@/app/components/base/form/types' +import { useToastContext } from '@/app/components/base/toast' +import Button from '@/app/components/base/button' +import { useRenderI18nObject } from '@/hooks/use-i18n' + +type OAuthClientSettingsProps = { + pluginPayload: PluginPayload + onClose?: () => void + editValues?: Record + disabled?: boolean + schemas: FormSchema[] + onAuth?: () => Promise + hasOriginalClientParams?: boolean + onUpdate?: () => void +} +const OAuthClientSettings = ({ + pluginPayload, + onClose, + editValues, + disabled, + schemas, + onAuth, + hasOriginalClientParams, + onUpdate, +}: OAuthClientSettingsProps) => { + const { t } = useTranslation() + const { notify } = useToastContext() + const [doingAction, setDoingAction] = useState(false) + const doingActionRef = useRef(doingAction) + const handleSetDoingAction = useCallback((value: boolean) => { + doingActionRef.current = value + setDoingAction(value) + }, []) + const defaultValues = schemas.reduce((acc, schema) => { + if (schema.default) + acc[schema.name] = schema.default + return acc + }, {} as Record) + const { mutateAsync: setPluginOAuthCustomClient } = useSetPluginOAuthCustomClientHook(pluginPayload) + const invalidPluginOAuthClientSchema = useInvalidPluginOAuthClientSchemaHook(pluginPayload) + const formRef = useRef(null) + const handleConfirm = useCallback(async () => { + if (doingActionRef.current) + return + + try { + const { + isCheckValidated, + values, + } = formRef.current?.getFormValues({ + needCheckValidatedValues: true, + needTransformWhenSecretFieldIsPristine: true, + }) || { isCheckValidated: false, values: {} } + if (!isCheckValidated) + throw new Error('error') + const { + __oauth_client__, + ...restValues + } = values + + handleSetDoingAction(true) + await setPluginOAuthCustomClient({ + client_params: restValues, + enable_oauth_custom_client: __oauth_client__ === 'custom', + }) + notify({ + type: 'success', + message: t('common.api.actionSuccess'), + }) + + onClose?.() + onUpdate?.() + invalidPluginOAuthClientSchema() + } + finally { + handleSetDoingAction(false) + } + }, [onClose, onUpdate, invalidPluginOAuthClientSchema, setPluginOAuthCustomClient, notify, t, handleSetDoingAction]) + + const handleConfirmAndAuthorize = useCallback(async () => { + await handleConfirm() + if (onAuth) + await onAuth() + }, [handleConfirm, onAuth]) + const { mutateAsync: deletePluginOAuthCustomClient } = useDeletePluginOAuthCustomClientHook(pluginPayload) + const handleRemove = useCallback(async () => { + if (doingActionRef.current) + return + + try { + handleSetDoingAction(true) + await deletePluginOAuthCustomClient() + notify({ + type: 'success', + message: t('common.api.actionSuccess'), + }) + onClose?.() + onUpdate?.() + invalidPluginOAuthClientSchema() + } + finally { + handleSetDoingAction(false) + } + }, [onUpdate, invalidPluginOAuthClientSchema, deletePluginOAuthCustomClient, notify, t, handleSetDoingAction, onClose]) + const form = useForm({ + defaultValues: editValues || defaultValues, + }) + const __oauth_client__ = useStore(form.store, s => s.values.__oauth_client__) + const helpField = schemas.find(schema => schema.url && schema.help) + const renderI18nObject = useRenderI18nObject() + return ( + + + + ) + } + > + <> + + { + helpField && __oauth_client__ === 'custom' && ( + + + {renderI18nObject(helpField?.help as any)} + + + + )} + + + ) +} + +export default memo(OAuthClientSettings) diff --git a/web/app/components/plugins/plugin-auth/authorized-in-node.tsx b/web/app/components/plugins/plugin-auth/authorized-in-node.tsx new file mode 100644 index 000000000..79189fa58 --- /dev/null +++ b/web/app/components/plugins/plugin-auth/authorized-in-node.tsx @@ -0,0 +1,113 @@ +import { + memo, + useCallback, + useState, +} from 'react' +import { useTranslation } from 'react-i18next' +import { RiArrowDownSLine } from '@remixicon/react' +import Button from '@/app/components/base/button' +import Indicator from '@/app/components/header/indicator' +import cn from '@/utils/classnames' +import type { + Credential, + PluginPayload, +} from './types' +import { + Authorized, + usePluginAuth, +} from '.' + +type AuthorizedInNodeProps = { + pluginPayload: PluginPayload + onAuthorizationItemClick: (id: string) => void + credentialId?: string +} +const AuthorizedInNode = ({ + pluginPayload, + onAuthorizationItemClick, + credentialId, +}: AuthorizedInNodeProps) => { + const { t } = useTranslation() + const [isOpen, setIsOpen] = useState(false) + const { + canApiKey, + canOAuth, + credentials, + disabled, + invalidPluginCredentialInfo, + } = usePluginAuth(pluginPayload, isOpen || !!credentialId) + const renderTrigger = useCallback((open?: boolean) => { + let label = '' + let removed = false + if (!credentialId) { + label = t('plugin.auth.workspaceDefault') + } + else { + const credential = credentials.find(c => c.id === credentialId) + label = credential ? credential.name : t('plugin.auth.authRemoved') + removed = !credential + } + return ( + + ) + }, [credentialId, credentials, t]) + const extraAuthorizationItems: Credential[] = [ + { + id: '__workspace_default__', + name: t('plugin.auth.workspaceDefault'), + provider: '', + is_default: !credentialId, + isWorkspaceDefault: true, + }, + ] + const handleAuthorizationItemClick = useCallback((id: string) => { + onAuthorizationItemClick(id) + setIsOpen(false) + }, [ + onAuthorizationItemClick, + setIsOpen, + ]) + + return ( + + ) +} + +export default memo(AuthorizedInNode) diff --git a/web/app/components/plugins/plugin-auth/authorized/index.tsx b/web/app/components/plugins/plugin-auth/authorized/index.tsx new file mode 100644 index 000000000..ac771afdd --- /dev/null +++ b/web/app/components/plugins/plugin-auth/authorized/index.tsx @@ -0,0 +1,342 @@ +import { + memo, + useCallback, + useRef, + useState, +} from 'react' +import { + RiArrowDownSLine, +} from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import type { + PortalToFollowElemOptions, +} from '@/app/components/base/portal-to-follow-elem' +import Button from '@/app/components/base/button' +import Indicator from '@/app/components/header/indicator' +import cn from '@/utils/classnames' +import Confirm from '@/app/components/base/confirm' +import Authorize from '../authorize' +import type { Credential } from '../types' +import { CredentialTypeEnum } from '../types' +import ApiKeyModal from '../authorize/api-key-modal' +import Item from './item' +import { useToastContext } from '@/app/components/base/toast' +import type { PluginPayload } from '../types' +import { + useDeletePluginCredentialHook, + useSetPluginDefaultCredentialHook, + useUpdatePluginCredentialHook, +} from '../hooks/use-credential' + +type AuthorizedProps = { + pluginPayload: PluginPayload + credentials: Credential[] + canOAuth?: boolean + canApiKey?: boolean + disabled?: boolean + renderTrigger?: (open?: boolean) => React.ReactNode + isOpen?: boolean + onOpenChange?: (open: boolean) => void + offset?: PortalToFollowElemOptions['offset'] + placement?: PortalToFollowElemOptions['placement'] + triggerPopupSameWidth?: boolean + popupClassName?: string + disableSetDefault?: boolean + onItemClick?: (id: string) => void + extraAuthorizationItems?: Credential[] + showItemSelectedIcon?: boolean + selectedCredentialId?: string + onUpdate?: () => void +} +const Authorized = ({ + pluginPayload, + credentials, + canOAuth, + canApiKey, + disabled, + renderTrigger, + isOpen, + onOpenChange, + offset = 8, + placement = 'bottom-start', + triggerPopupSameWidth = true, + popupClassName, + disableSetDefault, + onItemClick, + extraAuthorizationItems, + showItemSelectedIcon, + selectedCredentialId, + onUpdate, +}: AuthorizedProps) => { + const { t } = useTranslation() + const { notify } = useToastContext() + const [isLocalOpen, setIsLocalOpen] = useState(false) + const mergedIsOpen = isOpen ?? isLocalOpen + const setMergedIsOpen = useCallback((open: boolean) => { + if (onOpenChange) + onOpenChange(open) + + setIsLocalOpen(open) + }, [onOpenChange]) + const oAuthCredentials = credentials.filter(credential => credential.credential_type === CredentialTypeEnum.OAUTH2) + const apiKeyCredentials = credentials.filter(credential => credential.credential_type === CredentialTypeEnum.API_KEY) + const pendingOperationCredentialId = useRef(null) + const [deleteCredentialId, setDeleteCredentialId] = useState(null) + const { mutateAsync: deletePluginCredential } = useDeletePluginCredentialHook(pluginPayload) + const openConfirm = useCallback((credentialId?: string) => { + if (credentialId) + pendingOperationCredentialId.current = credentialId + + setDeleteCredentialId(pendingOperationCredentialId.current) + }, []) + const closeConfirm = useCallback(() => { + setDeleteCredentialId(null) + pendingOperationCredentialId.current = null + }, []) + const [doingAction, setDoingAction] = useState(false) + const doingActionRef = useRef(doingAction) + const handleSetDoingAction = useCallback((doing: boolean) => { + doingActionRef.current = doing + setDoingAction(doing) + }, []) + const handleConfirm = useCallback(async () => { + if (doingActionRef.current) + return + if (!pendingOperationCredentialId.current) { + setDeleteCredentialId(null) + return + } + try { + handleSetDoingAction(true) + await deletePluginCredential({ credential_id: pendingOperationCredentialId.current }) + notify({ + type: 'success', + message: t('common.api.actionSuccess'), + }) + onUpdate?.() + setDeleteCredentialId(null) + pendingOperationCredentialId.current = null + } + finally { + handleSetDoingAction(false) + } + }, [deletePluginCredential, onUpdate, notify, t, handleSetDoingAction]) + const [editValues, setEditValues] = useState | null>(null) + const handleEdit = useCallback((id: string, values: Record) => { + pendingOperationCredentialId.current = id + setEditValues(values) + }, []) + const handleRemove = useCallback(() => { + setDeleteCredentialId(pendingOperationCredentialId.current) + }, []) + const { mutateAsync: setPluginDefaultCredential } = useSetPluginDefaultCredentialHook(pluginPayload) + const handleSetDefault = useCallback(async (id: string) => { + if (doingActionRef.current) + return + try { + handleSetDoingAction(true) + await setPluginDefaultCredential(id) + notify({ + type: 'success', + message: t('common.api.actionSuccess'), + }) + onUpdate?.() + } + finally { + handleSetDoingAction(false) + } + }, [setPluginDefaultCredential, onUpdate, notify, t, handleSetDoingAction]) + const { mutateAsync: updatePluginCredential } = useUpdatePluginCredentialHook(pluginPayload) + const handleRename = useCallback(async (payload: { + credential_id: string + name: string + }) => { + if (doingActionRef.current) + return + try { + handleSetDoingAction(true) + await updatePluginCredential(payload) + notify({ + type: 'success', + message: t('common.api.actionSuccess'), + }) + onUpdate?.() + } + finally { + handleSetDoingAction(false) + } + }, [updatePluginCredential, notify, t, handleSetDoingAction, onUpdate]) + + return ( + <> + + setMergedIsOpen(!mergedIsOpen)} + asChild + > + { + renderTrigger + ? renderTrigger(mergedIsOpen) + : ( + + ) + } + + +
+
+ { + !!extraAuthorizationItems?.length && ( +
+ { + extraAuthorizationItems.map(credential => ( + + )) + } +
+ ) + } + { + !!oAuthCredentials.length && ( +
+
+ OAuth +
+ { + oAuthCredentials.map(credential => ( + + )) + } +
+ ) + } + { + !!apiKeyCredentials.length && ( +
+
+ API Keys +
+ { + apiKeyCredentials.map(credential => ( + + )) + } +
+ ) + } +
+
+
+ +
+
+
+
+ { + deleteCredentialId && ( + + ) + } + { + !!editValues && ( + { + setEditValues(null) + pendingOperationCredentialId.current = null + }} + onRemove={handleRemove} + disabled={disabled || doingAction} + onUpdate={onUpdate} + /> + ) + } + + ) +} + +export default memo(Authorized) diff --git a/web/app/components/plugins/plugin-auth/authorized/item.tsx b/web/app/components/plugins/plugin-auth/authorized/item.tsx new file mode 100644 index 000000000..5508bcc32 --- /dev/null +++ b/web/app/components/plugins/plugin-auth/authorized/item.tsx @@ -0,0 +1,219 @@ +import { + memo, + useMemo, + useState, +} from 'react' +import { useTranslation } from 'react-i18next' +import { + RiCheckLine, + RiDeleteBinLine, + RiEditLine, + RiEqualizer2Line, +} from '@remixicon/react' +import Indicator from '@/app/components/header/indicator' +import Badge from '@/app/components/base/badge' +import ActionButton from '@/app/components/base/action-button' +import Tooltip from '@/app/components/base/tooltip' +import Button from '@/app/components/base/button' +import Input from '@/app/components/base/input' +import cn from '@/utils/classnames' +import type { Credential } from '../types' +import { CredentialTypeEnum } from '../types' + +type ItemProps = { + credential: Credential + disabled?: boolean + onDelete?: (id: string) => void + onEdit?: (id: string, values: Record) => void + onSetDefault?: (id: string) => void + onRename?: (payload: { + credential_id: string + name: string + }) => void + disableRename?: boolean + disableEdit?: boolean + disableDelete?: boolean + disableSetDefault?: boolean + onItemClick?: (id: string) => void + showSelectedIcon?: boolean + selectedCredentialId?: string +} +const Item = ({ + credential, + disabled, + onDelete, + onEdit, + onSetDefault, + onRename, + disableRename, + disableEdit, + disableDelete, + disableSetDefault, + onItemClick, + showSelectedIcon, + selectedCredentialId, +}: ItemProps) => { + const { t } = useTranslation() + const [renaming, setRenaming] = useState(false) + const [renameValue, setRenameValue] = useState(credential.name) + const isOAuth = credential.credential_type === CredentialTypeEnum.OAUTH2 + const showAction = useMemo(() => { + return !(disableRename && disableEdit && disableDelete && disableSetDefault) + }, [disableRename, disableEdit, disableDelete, disableSetDefault]) + + return ( +
onItemClick?.(credential.id === '__workspace_default__' ? '' : credential.id)} + > + { + renaming && ( +
+ setRenameValue(e.target.value)} + placeholder={t('common.placeholder.input')} + onClick={e => e.stopPropagation()} + /> + + +
+ ) + } + { + !renaming && ( +
+ { + showSelectedIcon && ( +
+ { + selectedCredentialId === credential.id && ( + + ) + } +
+ ) + } + +
+ {credential.name} +
+ { + credential.is_default && ( + + {t('plugin.auth.default')} + + ) + } +
+ ) + } + { + showAction && !renaming && ( +
+ { + !credential.is_default && !disableSetDefault && ( + + ) + } + { + !disableRename && ( + + { + e.stopPropagation() + setRenaming(true) + setRenameValue(credential.name) + }} + > + + + + ) + } + { + !isOAuth && !disableEdit && ( + + { + e.stopPropagation() + onEdit?.( + credential.id, + { + ...credential.credentials, + __name__: credential.name, + __credential_id__: credential.id, + }, + ) + }} + > + + + + ) + } + { + !disableDelete && ( + + { + e.stopPropagation() + onDelete?.(credential.id) + }} + > + + + + ) + } +
+ ) + } +
+ ) +} + +export default memo(Item) diff --git a/web/app/components/plugins/plugin-auth/hooks/use-credential.ts b/web/app/components/plugins/plugin-auth/hooks/use-credential.ts new file mode 100644 index 000000000..5a7a497ad --- /dev/null +++ b/web/app/components/plugins/plugin-auth/hooks/use-credential.ts @@ -0,0 +1,88 @@ +import { + useAddPluginCredential, + useDeletePluginCredential, + useDeletePluginOAuthCustomClient, + useGetPluginCredentialInfo, + useGetPluginCredentialSchema, + useGetPluginOAuthClientSchema, + useGetPluginOAuthUrl, + useInvalidPluginCredentialInfo, + useInvalidPluginOAuthClientSchema, + useSetPluginDefaultCredential, + useSetPluginOAuthCustomClient, + useUpdatePluginCredential, +} from '@/service/use-plugins-auth' +import { useGetApi } from './use-get-api' +import type { PluginPayload } from '../types' +import type { CredentialTypeEnum } from '../types' + +export const useGetPluginCredentialInfoHook = (pluginPayload: PluginPayload, enable?: boolean) => { + const apiMap = useGetApi(pluginPayload) + return useGetPluginCredentialInfo(enable ? apiMap.getCredentialInfo : '') +} + +export const useDeletePluginCredentialHook = (pluginPayload: PluginPayload) => { + const apiMap = useGetApi(pluginPayload) + + return useDeletePluginCredential(apiMap.deleteCredential) +} + +export const useInvalidPluginCredentialInfoHook = (pluginPayload: PluginPayload) => { + const apiMap = useGetApi(pluginPayload) + + return useInvalidPluginCredentialInfo(apiMap.getCredentialInfo) +} + +export const useSetPluginDefaultCredentialHook = (pluginPayload: PluginPayload) => { + const apiMap = useGetApi(pluginPayload) + + return useSetPluginDefaultCredential(apiMap.setDefaultCredential) +} + +export const useGetPluginCredentialSchemaHook = (pluginPayload: PluginPayload, credentialType: CredentialTypeEnum) => { + const apiMap = useGetApi(pluginPayload) + + return useGetPluginCredentialSchema(apiMap.getCredentialSchema(credentialType)) +} + +export const useAddPluginCredentialHook = (pluginPayload: PluginPayload) => { + const apiMap = useGetApi(pluginPayload) + + return useAddPluginCredential(apiMap.addCredential) +} + +export const useUpdatePluginCredentialHook = (pluginPayload: PluginPayload) => { + const apiMap = useGetApi(pluginPayload) + + return useUpdatePluginCredential(apiMap.updateCredential) +} + +export const useGetPluginOAuthUrlHook = (pluginPayload: PluginPayload) => { + const apiMap = useGetApi(pluginPayload) + + return useGetPluginOAuthUrl(apiMap.getOauthUrl) +} + +export const useGetPluginOAuthClientSchemaHook = (pluginPayload: PluginPayload) => { + const apiMap = useGetApi(pluginPayload) + + return useGetPluginOAuthClientSchema(apiMap.getOauthClientSchema) +} + +export const useInvalidPluginOAuthClientSchemaHook = (pluginPayload: PluginPayload) => { + const apiMap = useGetApi(pluginPayload) + + return useInvalidPluginOAuthClientSchema(apiMap.getOauthClientSchema) +} + +export const useSetPluginOAuthCustomClientHook = (pluginPayload: PluginPayload) => { + const apiMap = useGetApi(pluginPayload) + + return useSetPluginOAuthCustomClient(apiMap.setCustomOauthClient) +} + +export const useDeletePluginOAuthCustomClientHook = (pluginPayload: PluginPayload) => { + const apiMap = useGetApi(pluginPayload) + + return useDeletePluginOAuthCustomClient(apiMap.deleteCustomOAuthClient) +} diff --git a/web/app/components/plugins/plugin-auth/hooks/use-get-api.ts b/web/app/components/plugins/plugin-auth/hooks/use-get-api.ts new file mode 100644 index 000000000..14199ddc4 --- /dev/null +++ b/web/app/components/plugins/plugin-auth/hooks/use-get-api.ts @@ -0,0 +1,41 @@ +import { + AuthCategory, +} from '../types' +import type { + CredentialTypeEnum, + PluginPayload, +} from '../types' + +export const useGetApi = ({ category = AuthCategory.tool, provider }: PluginPayload) => { + if (category === AuthCategory.tool) { + return { + getCredentialInfo: `/workspaces/current/tool-provider/builtin/${provider}/credential/info`, + setDefaultCredential: `/workspaces/current/tool-provider/builtin/${provider}/default-credential`, + getCredentials: `/workspaces/current/tool-provider/builtin/${provider}/credentials`, + addCredential: `/workspaces/current/tool-provider/builtin/${provider}/add`, + updateCredential: `/workspaces/current/tool-provider/builtin/${provider}/update`, + deleteCredential: `/workspaces/current/tool-provider/builtin/${provider}/delete`, + getCredentialSchema: (credential_type: CredentialTypeEnum) => `/workspaces/current/tool-provider/builtin/${provider}/credential/schema/${credential_type}`, + getOauthUrl: `/oauth/plugin/${provider}/tool/authorization-url`, + getOauthClientSchema: `/workspaces/current/tool-provider/builtin/${provider}/oauth/client-schema`, + setCustomOauthClient: `/workspaces/current/tool-provider/builtin/${provider}/oauth/custom-client`, + getCustomOAuthClientValues: `/workspaces/current/tool-provider/builtin/${provider}/oauth/custom-client`, + deleteCustomOAuthClient: `/workspaces/current/tool-provider/builtin/${provider}/oauth/custom-client`, + } + } + + return { + getCredentialInfo: '', + setDefaultCredential: '', + getCredentials: '', + addCredential: '', + updateCredential: '', + deleteCredential: '', + getCredentialSchema: () => '', + getOauthUrl: '', + getOauthClientSchema: '', + setCustomOauthClient: '', + getCustomOAuthClientValues: '', + deleteCustomOAuthClient: '', + } +} diff --git a/web/app/components/plugins/plugin-auth/hooks/use-plugin-auth.ts b/web/app/components/plugins/plugin-auth/hooks/use-plugin-auth.ts new file mode 100644 index 000000000..e449a4bb6 --- /dev/null +++ b/web/app/components/plugins/plugin-auth/hooks/use-plugin-auth.ts @@ -0,0 +1,25 @@ +import { useAppContext } from '@/context/app-context' +import { + useGetPluginCredentialInfoHook, + useInvalidPluginCredentialInfoHook, +} from './use-credential' +import { CredentialTypeEnum } from '../types' +import type { PluginPayload } from '../types' + +export const usePluginAuth = (pluginPayload: PluginPayload, enable?: boolean) => { + const { data } = useGetPluginCredentialInfoHook(pluginPayload, enable) + const { isCurrentWorkspaceManager } = useAppContext() + const isAuthorized = !!data?.credentials.length + const canOAuth = data?.supported_credential_types.includes(CredentialTypeEnum.OAUTH2) + const canApiKey = data?.supported_credential_types.includes(CredentialTypeEnum.API_KEY) + const invalidPluginCredentialInfo = useInvalidPluginCredentialInfoHook(pluginPayload) + + return { + isAuthorized, + canOAuth, + canApiKey, + credentials: data?.credentials || [], + disabled: !isCurrentWorkspaceManager, + invalidPluginCredentialInfo, + } +} diff --git a/web/app/components/plugins/plugin-auth/index.tsx b/web/app/components/plugins/plugin-auth/index.tsx new file mode 100644 index 000000000..e4f6ae8b2 --- /dev/null +++ b/web/app/components/plugins/plugin-auth/index.tsx @@ -0,0 +1,6 @@ +export { default as PluginAuth } from './plugin-auth' +export { default as Authorized } from './authorized' +export { default as AuthorizedInNode } from './authorized-in-node' +export { default as PluginAuthInAgent } from './plugin-auth-in-agent' +export { usePluginAuth } from './hooks/use-plugin-auth' +export * from './types' diff --git a/web/app/components/plugins/plugin-auth/plugin-auth-in-agent.tsx b/web/app/components/plugins/plugin-auth/plugin-auth-in-agent.tsx new file mode 100644 index 000000000..f3557f3d6 --- /dev/null +++ b/web/app/components/plugins/plugin-auth/plugin-auth-in-agent.tsx @@ -0,0 +1,123 @@ +import { + memo, + useCallback, + useState, +} from 'react' +import { RiArrowDownSLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import Authorize from './authorize' +import Authorized from './authorized' +import type { + Credential, + PluginPayload, +} from './types' +import { usePluginAuth } from './hooks/use-plugin-auth' +import Button from '@/app/components/base/button' +import Indicator from '@/app/components/header/indicator' +import cn from '@/utils/classnames' + +type PluginAuthInAgentProps = { + pluginPayload: PluginPayload + credentialId?: string + onAuthorizationItemClick?: (id: string) => void +} +const PluginAuthInAgent = ({ + pluginPayload, + credentialId, + onAuthorizationItemClick, +}: PluginAuthInAgentProps) => { + const { t } = useTranslation() + const [isOpen, setIsOpen] = useState(false) + const { + isAuthorized, + canOAuth, + canApiKey, + credentials, + disabled, + invalidPluginCredentialInfo, + } = usePluginAuth(pluginPayload, true) + + const extraAuthorizationItems: Credential[] = [ + { + id: '__workspace_default__', + name: t('plugin.auth.workspaceDefault'), + provider: '', + is_default: !credentialId, + isWorkspaceDefault: true, + }, + ] + + const handleAuthorizationItemClick = useCallback((id: string) => { + onAuthorizationItemClick?.(id) + setIsOpen(false) + }, [ + onAuthorizationItemClick, + setIsOpen, + ]) + + const renderTrigger = useCallback((isOpen?: boolean) => { + let label = '' + let removed = false + if (!credentialId) { + label = t('plugin.auth.workspaceDefault') + } + else { + const credential = credentials.find(c => c.id === credentialId) + label = credential ? credential.name : t('plugin.auth.authRemoved') + removed = !credential + } + return ( + + ) + }, [credentialId, credentials, t]) + + return ( + <> + { + !isAuthorized && ( + + ) + } + { + isAuthorized && ( + + ) + } + + ) +} + +export default memo(PluginAuthInAgent) diff --git a/web/app/components/plugins/plugin-auth/plugin-auth.tsx b/web/app/components/plugins/plugin-auth/plugin-auth.tsx new file mode 100644 index 000000000..76b405a75 --- /dev/null +++ b/web/app/components/plugins/plugin-auth/plugin-auth.tsx @@ -0,0 +1,59 @@ +import { memo } from 'react' +import Authorize from './authorize' +import Authorized from './authorized' +import type { PluginPayload } from './types' +import { usePluginAuth } from './hooks/use-plugin-auth' +import cn from '@/utils/classnames' + +type PluginAuthProps = { + pluginPayload: PluginPayload + children?: React.ReactNode + className?: string +} +const PluginAuth = ({ + pluginPayload, + children, + className, +}: PluginAuthProps) => { + const { + isAuthorized, + canOAuth, + canApiKey, + credentials, + disabled, + invalidPluginCredentialInfo, + } = usePluginAuth(pluginPayload, !!pluginPayload.provider) + + return ( +
+ { + !isAuthorized && ( + + ) + } + { + isAuthorized && !children && ( + + ) + } + { + isAuthorized && children + } +
+ ) +} + +export default memo(PluginAuth) diff --git a/web/app/components/plugins/plugin-auth/types.ts b/web/app/components/plugins/plugin-auth/types.ts new file mode 100644 index 000000000..ad41733bd --- /dev/null +++ b/web/app/components/plugins/plugin-auth/types.ts @@ -0,0 +1,25 @@ +export enum AuthCategory { + tool = 'tool', + datasource = 'datasource', + model = 'model', +} + +export type PluginPayload = { + category: AuthCategory + provider: string +} + +export enum CredentialTypeEnum { + OAUTH2 = 'oauth2', + API_KEY = 'api-key', +} + +export type Credential = { + id: string + name: string + provider: string + credential_type?: CredentialTypeEnum + is_default: boolean + credentials?: Record + isWorkspaceDefault?: boolean +} diff --git a/web/app/components/plugins/plugin-auth/utils.ts b/web/app/components/plugins/plugin-auth/utils.ts new file mode 100644 index 000000000..d264cfb19 --- /dev/null +++ b/web/app/components/plugins/plugin-auth/utils.ts @@ -0,0 +1,10 @@ +export const transformFormSchemasSecretInput = (isPristineSecretInputNames: string[], values: Record) => { + const transformedValues: Record = { ...values } + + isPristineSecretInputNames.forEach((name) => { + if (transformedValues[name]) + transformedValues[name] = '[__HIDDEN__]' + }) + + return transformedValues +} diff --git a/web/app/components/plugins/plugin-detail-panel/action-list.tsx b/web/app/components/plugins/plugin-detail-panel/action-list.tsx index 2505b6d5a..040c72863 100644 --- a/web/app/components/plugins/plugin-detail-panel/action-list.tsx +++ b/web/app/components/plugins/plugin-detail-panel/action-list.tsx @@ -1,17 +1,9 @@ -import React, { useMemo, useState } from 'react' +import React, { useMemo } from 'react' import { useTranslation } from 'react-i18next' -import { useAppContext } from '@/context/app-context' -import Button from '@/app/components/base/button' -import Toast from '@/app/components/base/toast' -import Indicator from '@/app/components/header/indicator' import ToolItem from '@/app/components/tools/provider/tool-item' -import ConfigCredential from '@/app/components/tools/setting/build-in/config-credentials' import { useAllToolProviders, useBuiltinTools, - useInvalidateAllToolProviders, - useRemoveProviderCredentials, - useUpdateProviderCredentials, } from '@/service/use-tools' import type { PluginDetail } from '@/app/components/plugins/types' @@ -23,35 +15,14 @@ const ActionList = ({ detail, }: Props) => { const { t } = useTranslation() - const { isCurrentWorkspaceManager } = useAppContext() const providerBriefInfo = detail.declaration.tool.identity const providerKey = `${detail.plugin_id}/${providerBriefInfo.name}` const { data: collectionList = [] } = useAllToolProviders() - const invalidateAllToolProviders = useInvalidateAllToolProviders() const provider = useMemo(() => { return collectionList.find(collection => collection.name === providerKey) }, [collectionList, providerKey]) const { data } = useBuiltinTools(providerKey) - const [showSettingAuth, setShowSettingAuth] = useState(false) - - const handleCredentialSettingUpdate = () => { - invalidateAllToolProviders() - Toast.notify({ - type: 'success', - message: t('common.api.actionSuccess'), - }) - setShowSettingAuth(false) - } - - const { mutate: updatePermission, isPending } = useUpdateProviderCredentials({ - onSuccess: handleCredentialSettingUpdate, - }) - - const { mutate: removePermission } = useRemoveProviderCredentials({ - onSuccess: handleCredentialSettingUpdate, - }) - if (!data || !provider) return null @@ -60,26 +31,7 @@ const ActionList = ({
{t('plugin.detailPanel.actionNum', { num: data.length, action: data.length > 1 ? 'actions' : 'action' })} - {provider.is_team_authorization && provider.allow_delete && ( - - )}
- {!provider.is_team_authorization && provider.allow_delete && ( - - )}
{data.map(tool => ( @@ -93,18 +45,6 @@ const ActionList = ({ /> ))}
- {showSettingAuth && ( - setShowSettingAuth(false)} - onSaved={async value => updatePermission({ - providerName: provider.name, - credentials: value, - })} - onRemove={async () => removePermission(provider.name)} - isSaving={isPending} - /> - )} ) } diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header.tsx b/web/app/components/plugins/plugin-detail-panel/detail-header.tsx index 0a5a8b87d..124e133c2 100644 --- a/web/app/components/plugins/plugin-detail-panel/detail-header.tsx +++ b/web/app/components/plugins/plugin-detail-panel/detail-header.tsx @@ -36,6 +36,9 @@ import { useInvalidateAllToolProviders } from '@/service/use-tools' import { API_PREFIX } from '@/config' import cn from '@/utils/classnames' import { getMarketplaceUrl } from '@/utils/var' +import { PluginAuth } from '@/app/components/plugins/plugin-auth' +import { AuthCategory } from '@/app/components/plugins/plugin-auth' +import { useAllToolProviders } from '@/service/use-tools' const i18nPrefix = 'plugin.action' @@ -68,7 +71,14 @@ const DetailHeader = ({ meta, plugin_id, } = detail - const { author, category, name, label, description, icon, verified } = detail.declaration + const { author, category, name, label, description, icon, verified, tool } = detail.declaration + const isTool = category === PluginType.tool + const providerBriefInfo = tool?.identity + const providerKey = `${plugin_id}/${providerBriefInfo?.name}` + const { data: collectionList = [] } = useAllToolProviders(isTool) + const provider = useMemo(() => { + return collectionList.find(collection => collection.name === providerKey) + }, [collectionList, providerKey]) const isFromGitHub = source === PluginSource.github const isFromMarketplace = source === PluginSource.marketplace @@ -262,7 +272,17 @@ const DetailHeader = ({ - + + { + category === PluginType.tool && ( + + ) + } {isShowPluginInfo && ( = ({ } as any) } - // authorization - const { isCurrentWorkspaceManager } = useAppContext() - const [isShowSettingAuth, setShowSettingAuth] = useState(false) - const handleCredentialSettingUpdate = () => { - invalidateAllBuiltinTools() - Toast.notify({ - type: 'success', - message: t('common.api.actionSuccess'), - }) - setShowSettingAuth(false) - onShowChange(false) - } - - const { mutate: updatePermission } = useUpdateProviderCredentials({ - onSuccess: handleCredentialSettingUpdate, - }) - // install from marketplace const currentTool = useMemo(() => { return currentProvider?.tools.find(tool => tool.name === value?.tool_name) @@ -226,6 +203,12 @@ const ToolSelector: FC = ({ invalidateAllBuiltinTools() invalidateInstalledPluginList() } + const handleAuthorizationItemClick = (id: string) => { + onSelect({ + ...value, + credential_id: id, + } as any) + } return ( <> @@ -264,7 +247,6 @@ const ToolSelector: FC = ({ onSwitchChange={handleEnabledChange} onDelete={onDelete} noAuth={currentProvider && currentTool && !currentProvider.is_team_authorization} - onAuth={() => setShowSettingAuth(true)} uninstalled={!currentProvider && inMarketPlace} versionMismatch={currentProvider && inMarketPlace && !currentTool} installInfo={manifest?.latest_package_identifier} @@ -284,171 +266,131 @@ const ToolSelector: FC = ({ )}
-
- {!isShowSettingAuth && ( - <> -
{t(`plugin.detailPanel.toolSelector.${isEdit ? 'toolSetting' : 'title'}`)}
- {/* base form */} -
-
-
{t('plugin.detailPanel.toolSelector.toolLabel')}
- - } - isShow={panelShowState || isShowChooseTool} - onShowChange={trigger ? onPanelShowStateChange as any : setIsShowChooseTool} - disabled={false} - supportAddCustomTool - onSelect={handleSelectTool} - onSelectMultiple={handleSelectMultipleTool} - scope={scope} - selectedTools={selectedTools} - canChooseMCPTool={canChooseMCPTool} - /> -
-
-
{t('plugin.detailPanel.toolSelector.descriptionLabel')}
-