Support OAuth Integration for Plugin Tools (#22550)

Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
Maries
2025-07-17 17:18:44 +08:00
committed by GitHub
parent 965e952336
commit a4ef900916
89 changed files with 5516 additions and 875 deletions

View File

@@ -4,7 +4,7 @@ from openai import BaseModel
from pydantic import Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolInvokeFrom
from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
class ToolRuntime(BaseModel):
@@ -17,6 +17,7 @@ class ToolRuntime(BaseModel):
invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None
credentials: dict[str, Any] = Field(default_factory=dict)
credential_type: CredentialType = Field(default=CredentialType.API_KEY)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)

View File

@@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
from core.tools.entities.tool_entities import (
CredentialType,
OAuthSchema,
ToolEntity,
ToolProviderEntity,
ToolProviderType,
)
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import (
ToolProviderNotFoundError,
@@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController):
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
credentials_schema.append(credential_dict)
oauth_schema = None
if provider_yaml.get("oauth_schema", None) is not None:
oauth_schema = OAuthSchema(
client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []),
credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []),
)
super().__init__(
entity=ToolProviderEntity(
identity=provider_yaml["identity"],
credentials_schema=credentials_schema,
oauth_schema=oauth_schema,
),
)
@@ -97,10 +111,39 @@ class BuiltinToolProviderController(ToolProviderController):
:return: the credentials schema
"""
if not self.entity.credentials_schema:
return []
return self.get_credentials_schema_by_type(CredentialType.API_KEY.value)
return self.entity.credentials_schema.copy()
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider
:param credential_type: the type of the credential
:return: the credentials schema of the provider
"""
if credential_type == CredentialType.OAUTH2.value:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
if credential_type == CredentialType.API_KEY.value:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
raise ValueError(f"Invalid credential type: {credential_type}")
def get_oauth_client_schema(self) -> list[ProviderConfig]:
"""
returns the oauth client schema of the provider
:return: the oauth client schema
"""
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
def get_supported_credential_types(self) -> list[str]:
"""
returns the credential support type of the provider
"""
types = []
if self.entity.credentials_schema is not None and len(self.entity.credentials_schema) > 0:
types.append(CredentialType.API_KEY.value)
if self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) > 0:
types.append(CredentialType.OAUTH2.value)
return types
def get_tools(self) -> list[BuiltinTool]:
"""
@@ -123,7 +166,11 @@ class BuiltinToolProviderController(ToolProviderController):
:return: whether the provider needs credentials
"""
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
return (
self.entity.credentials_schema is not None
and len(self.entity.credentials_schema) != 0
or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0)
)
@property
def provider_type(self) -> ToolProviderType:

View File

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, field_validator
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.tool_entities import CredentialType, ToolProviderType
class ToolApiEntity(BaseModel):
@@ -87,3 +87,22 @@ class ToolProviderApiEntity(BaseModel):
def optional_field(self, key: str, value: Any) -> dict:
"""Return dict with key-value if value is truthy, empty dict otherwise."""
return {key: value} if value else {}
class ToolProviderCredentialApiEntity(BaseModel):
id: str = Field(description="The unique id of the credential")
name: str = Field(description="The name of the credential")
provider: str = Field(description="The provider of the credential")
credential_type: CredentialType = Field(description="The type of the credential")
is_default: bool = Field(
default=False, description="Whether the credential is the default credential for the provider in the workspace"
)
credentials: dict = Field(description="The credentials of the provider")
class ToolProviderCredentialInfoApiEntity(BaseModel):
supported_credential_types: list[str] = Field(description="The supported credential types of the provider")
is_oauth_custom_client_enabled: bool = Field(
default=False, description="Whether the OAuth custom client is enabled for the provider"
)
credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider")

View File

@@ -370,10 +370,18 @@ class ToolEntity(BaseModel):
return v or []
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
credentials_schema: list[ProviderConfig] = Field(
default_factory=list, description="The schema of the OAuth credentials"
)
class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
plugin_id: Optional[str] = None
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
oauth_schema: Optional[OAuthSchema] = None
class ToolProviderEntityWithPlugin(ToolProviderEntity):
@@ -453,6 +461,7 @@ class ToolSelector(BaseModel):
options: Optional[list[PluginParameterOption]] = None
provider_id: str = Field(..., description="The id of the provider")
credential_id: Optional[str] = Field(default=None, description="The id of the credential")
tool_name: str = Field(..., description="The name of the tool")
tool_description: str = Field(..., description="The description of the tool")
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
@@ -460,3 +469,36 @@ class ToolSelector(BaseModel):
def to_plugin_parameter(self) -> dict[str, Any]:
return self.model_dump()
class CredentialType(enum.StrEnum):
API_KEY = "api-key"
OAUTH2 = "oauth2"
def get_name(self):
if self == CredentialType.API_KEY:
return "API KEY"
elif self == CredentialType.OAUTH2:
return "AUTH"
else:
return self.value.replace("-", " ").upper()
def is_editable(self):
return self == CredentialType.API_KEY
def is_validate_allowed(self):
return self == CredentialType.API_KEY
@classmethod
def values(cls):
return [item.value for item in cls]
@classmethod
def of(cls, credential_type: str) -> "CredentialType":
type_name = credential_type.lower()
if type_name == "api-key":
return cls.API_KEY
elif type_name == "oauth2":
return cls.OAUTH2
else:
raise ValueError(f"Invalid credential type: {credential_type}")

View File

@@ -44,6 +44,7 @@ class PluginTool(Tool):
tool_provider=self.entity.identity.provider,
tool_name=self.entity.identity.name,
credentials=self.runtime.credentials,
credential_type=self.runtime.credential_type,
tool_parameters=tool_parameters,
conversation_id=conversation_id,
app_id=app_id,

View File

@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from yarl import URL
import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
@@ -17,6 +18,7 @@ from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.workflow.entities.variable_pool import VariablePool
from services.tools.mcp_tools_mange_service import MCPToolManageService
@@ -24,7 +26,6 @@ from services.tools.mcp_tools_mange_service import MCPToolManageService
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -41,16 +42,17 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
CredentialType,
ToolInvokeFrom,
ToolParameter,
ToolProviderType,
)
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
from core.tools.errors import ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
ProviderConfigEncrypter,
ToolParameterConfigurationManager,
)
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
@@ -68,8 +70,11 @@ class ToolManager:
@classmethod
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
"""
get the hardcoded provider
"""
if len(cls._hardcoded_providers) == 0:
# init the builtin providers
cls.load_hardcoded_providers_cache()
@@ -113,7 +118,12 @@ class ToolManager:
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(Lock())
plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers:
return plugin_tool_providers[provider]
with contexts.plugin_tool_providers_lock.get():
# double check
plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers:
return plugin_tool_providers[provider]
@@ -131,25 +141,7 @@ class ToolManager:
)
plugin_tool_providers[provider] = controller
return controller
@classmethod
def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
"""
get the builtin tool
:param provider: the name of the provider
:param tool_name: the name of the tool
:param tenant_id: the id of the tenant
:return: the provider, the tool
"""
provider_controller = cls.get_builtin_provider(provider, tenant_id)
tool = provider_controller.get_tool(tool_name)
if tool is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
return tool
return controller
@classmethod
def get_tool_runtime(
@@ -160,6 +152,7 @@ class ToolManager:
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
credential_id: Optional[str] = None,
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
"""
get the tool runtime
@@ -170,6 +163,7 @@ class ToolManager:
:param tenant_id: the tenant id
:param invoke_from: invoke from
:param tool_invoke_from: the tool invoke from
:param credential_id: the credential id
:return: the tool
"""
@@ -193,49 +187,70 @@ class ToolManager:
)
),
)
builtin_provider = None
if isinstance(provider_controller, PluginToolProviderController):
provider_id_entity = ToolProviderID(provider_id)
# get credentials
builtin_provider: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.first()
)
# get specific credentials
if is_valid_uuid(credential_id):
try:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
except Exception as e:
builtin_provider = None
logger.info(f"Error getting builtin provider {credential_id}:{e}", exc_info=True)
# if the provider has been deleted, raise an error
if builtin_provider is None:
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
# fallback to the default provider
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
# use the default provider
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
else:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
# decrypt the credentials
credentials = builtin_provider.credentials
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
],
cache=ToolProviderCredentialsCache(
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
),
)
decrypted_credentials = tool_configuration.decrypt(credentials)
return cast(
BuiltinTool,
builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=decrypted_credentials,
credentials=encrypter.decrypt(builtin_provider.credentials),
credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
@@ -245,22 +260,16 @@ class ToolManager:
elif provider_type == ToolProviderType.API:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
# decrypt the credentials
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
provider_type=api_provider.provider_type.value,
provider_identity=api_provider.entity.identity.name,
controller=api_provider,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
return cast(
ApiTool,
api_provider.get_tool(tool_name).fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=decrypted_credentials,
credentials=encrypter.decrypt(credentials),
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
@@ -320,6 +329,7 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.AGENT,
credential_id=agent_tool.credential_id,
)
runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters()
@@ -362,6 +372,7 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
credential_id=workflow_tool.credential_id,
)
parameters = tool_runtime.get_merged_runtime_parameters()
@@ -391,6 +402,7 @@ class ToolManager:
provider: str,
tool_name: str,
tool_parameters: dict[str, Any],
credential_id: Optional[str] = None,
) -> Tool:
"""
get tool runtime from plugin
@@ -402,6 +414,7 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.PLUGIN,
credential_id=credential_id,
)
runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters()
@@ -551,6 +564,22 @@ class ToolManager:
return cls._builtin_tools_labels[tool_name]
@classmethod
def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]:
"""
list all the builtin providers
"""
# according to multi credentials, select the one with is_default=True first, then created_at oldest
# for compatibility with old version
sql = """
SELECT DISTINCT ON (tenant_id, provider) id
FROM tool_builtin_providers
WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
@classmethod
def list_providers_from_api(
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
@@ -565,21 +594,13 @@ class ToolManager:
with db.session.no_autoflush:
if "builtin" in filters:
# get builtin providers
builtin_providers = cls.list_builtin_providers(tenant_id)
# get db builtin providers
db_builtin_providers: list[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
)
# rewrite db_builtin_providers
for db_provider in db_builtin_providers:
tool_provider_id = str(ToolProviderID(db_provider.provider))
db_provider.provider = tool_provider_id
def find_db_builtin_provider(provider):
return next((x for x in db_builtin_providers if x.provider == provider), None)
# key: provider name, value: provider
db_builtin_providers = {
str(ToolProviderID(provider.provider)): provider
for provider in cls.list_default_builtin_providers(tenant_id)
}
# append builtin providers
for provider in builtin_providers:
@@ -591,10 +612,9 @@ class ToolManager:
name_func=lambda x: x.identity.name,
):
continue
user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider,
db_provider=find_db_builtin_provider(provider.entity.identity.name),
db_provider=db_builtin_providers.get(provider.entity.identity.name),
decrypt_credentials=False,
)
@@ -604,7 +624,6 @@ class ToolManager:
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
# get db api providers
if "api" in filters:
db_api_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
@@ -764,15 +783,12 @@ class ToolManager:
auth_type,
)
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name,
controller=controller,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
try:
icon = json.loads(provider_obj.icon)

View File

@@ -1,12 +1,8 @@
from copy import deepcopy
from typing import Any
from pydantic import BaseModel
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
@@ -14,110 +10,6 @@ from core.tools.entities.tool_entities import (
)
class ProviderConfigEncrypter(BaseModel):
tenant_id: str
config: list[BasicProviderConfig]
provider_type: str
provider_identity: str
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, str]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
if use_cache:
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cached_credentials = cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
try:
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception:
pass
if use_cache:
cache.set(data)
return data
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cache.delete()
class ToolParameterConfigurationManager:
"""
Tool parameter configuration manager

View File

@@ -0,0 +1,142 @@
from copy import deepcopy
from typing import Any, Optional, Protocol
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.tools.__base.tool_provider import ToolProviderController
class ProviderConfigCache(Protocol):
"""
Interface for provider configuration cache operations
"""
def get(self) -> Optional[dict]:
"""Get cached provider configuration"""
...
def set(self, config: dict[str, Any]) -> None:
"""Cache provider configuration"""
...
def delete(self) -> None:
"""Delete cached provider configuration"""
...
class ProviderConfigEncrypter:
tenant_id: str
config: list[BasicProviderConfig]
provider_config_cache: ProviderConfigCache
def __init__(
self,
tenant_id: str,
config: list[BasicProviderConfig],
provider_config_cache: ProviderConfigCache,
):
self.tenant_id = tenant_id
self.config = config
self.provider_config_cache = provider_config_cache
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
cached_credentials = self.provider_config_cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
try:
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception:
pass
self.provider_config_cache.set(data)
return data
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
cache = SingletonProviderCredentialsCache(
tenant_id=tenant_id,
provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name,
)
encrypt = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_config_cache=cache,
)
return encrypt, cache

View File

@@ -0,0 +1,187 @@
import base64
import hashlib
import logging
from collections.abc import Mapping
from typing import Any, Optional
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from pydantic import TypeAdapter
from configs import dify_config
logger = logging.getLogger(__name__)
class OAuthEncryptionError(Exception):
"""OAuth encryption/decryption specific error"""
pass
class SystemOAuthEncrypter:
"""
A simple OAuth parameters encrypter using AES-CBC encryption.
This class provides methods to encrypt and decrypt OAuth parameters
using AES-CBC mode with a key derived from the application's SECRET_KEY.
"""
def __init__(self, secret_key: Optional[str] = None):
"""
Initialize the OAuth encrypter.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Raises:
ValueError: If SECRET_KEY is not configured or empty
"""
secret_key = secret_key or dify_config.SECRET_KEY or ""
# Generate a fixed 256-bit key using SHA-256
self.key = hashlib.sha256(secret_key.encode()).digest()
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters.
Args:
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
Returns:
Base64-encoded encrypted string
Raises:
OAuthEncryptionError: If encryption fails
ValueError: If oauth_params is invalid
"""
try:
# Generate random IV (16 bytes)
iv = get_random_bytes(16)
# Create AES cipher (CBC mode)
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Encrypt data
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
# Combine IV and encrypted data
combined = iv + encrypted_data
# Return base64 encoded string
return base64.b64encode(combined).decode()
except Exception as e:
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Raises:
OAuthEncryptionError: If decryption fails
ValueError: If encrypted_data is invalid
"""
if not isinstance(encrypted_data, str):
raise ValueError("encrypted_data must be a string")
if not encrypted_data:
raise ValueError("encrypted_data cannot be empty")
try:
# Base64 decode
combined = base64.b64decode(encrypted_data)
# Check minimum length (IV + at least one AES block)
if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data
raise ValueError("Invalid encrypted data format")
# Separate IV and encrypted data
iv = combined[:16]
encrypted_data_bytes = combined[16:]
# Create AES cipher
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Decrypt data
decrypted_data = cipher.decrypt(encrypted_data_bytes)
unpadded_data = unpad(decrypted_data, AES.block_size)
# Parse JSON
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
if not isinstance(oauth_params, dict):
raise ValueError("Decrypted data is not a valid dictionary")
return oauth_params
except Exception as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
# Factory function for creating encrypter instances
def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter:
"""
Create an OAuth encrypter instance.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Returns:
SystemOAuthEncrypter instance
"""
return SystemOAuthEncrypter(secret_key=secret_key)
# Global encrypter instance (for backward compatibility)
_oauth_encrypter: Optional[SystemOAuthEncrypter] = None
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
"""
Get the global OAuth encrypter instance.
Returns:
SystemOAuthEncrypter instance
"""
global _oauth_encrypter
if _oauth_encrypter is None:
_oauth_encrypter = SystemOAuthEncrypter()
return _oauth_encrypter
# Convenience functions for backward compatibility
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters using the global encrypter.
Args:
oauth_params: OAuth parameters dictionary
Returns:
Base64-encoded encrypted string
"""
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters using the global encrypter.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
"""
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)

View File

@@ -1,7 +1,9 @@
import uuid
def is_valid_uuid(uuid_str: str) -> bool:
def is_valid_uuid(uuid_str: str | None) -> bool:
if uuid_str is None or len(uuid_str) == 0:
return False
try:
uuid.UUID(uuid_str)
return True