Support OAuth Integration for Plugin Tools (#22550)
Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
@@ -4,7 +4,7 @@ from openai import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.entities.tool_entities import ToolInvokeFrom
|
||||
from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
|
||||
|
||||
|
||||
class ToolRuntime(BaseModel):
|
||||
@@ -17,6 +17,7 @@ class ToolRuntime(BaseModel):
|
||||
invoke_from: Optional[InvokeFrom] = None
|
||||
tool_invoke_from: Optional[ToolInvokeFrom] = None
|
||||
credentials: dict[str, Any] = Field(default_factory=dict)
|
||||
credential_type: CredentialType = Field(default=CredentialType.API_KEY)
|
||||
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
|
@@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
|
||||
from core.tools.entities.tool_entities import (
|
||||
CredentialType,
|
||||
OAuthSchema,
|
||||
ToolEntity,
|
||||
ToolProviderEntity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
||||
from core.tools.errors import (
|
||||
ToolProviderNotFoundError,
|
||||
@@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
|
||||
credentials_schema.append(credential_dict)
|
||||
|
||||
oauth_schema = None
|
||||
if provider_yaml.get("oauth_schema", None) is not None:
|
||||
oauth_schema = OAuthSchema(
|
||||
client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []),
|
||||
credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
entity=ToolProviderEntity(
|
||||
identity=provider_yaml["identity"],
|
||||
credentials_schema=credentials_schema,
|
||||
oauth_schema=oauth_schema,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -97,10 +111,39 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
if not self.entity.credentials_schema:
|
||||
return []
|
||||
return self.get_credentials_schema_by_type(CredentialType.API_KEY.value)
|
||||
|
||||
return self.entity.credentials_schema.copy()
|
||||
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:param credential_type: the type of the credential
|
||||
:return: the credentials schema of the provider
|
||||
"""
|
||||
if credential_type == CredentialType.OAUTH2.value:
|
||||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||
if credential_type == CredentialType.API_KEY.value:
|
||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
def get_oauth_client_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the oauth client schema of the provider
|
||||
|
||||
:return: the oauth client schema
|
||||
"""
|
||||
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
|
||||
|
||||
def get_supported_credential_types(self) -> list[str]:
|
||||
"""
|
||||
returns the credential support type of the provider
|
||||
"""
|
||||
types = []
|
||||
if self.entity.credentials_schema is not None and len(self.entity.credentials_schema) > 0:
|
||||
types.append(CredentialType.API_KEY.value)
|
||||
if self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) > 0:
|
||||
types.append(CredentialType.OAUTH2.value)
|
||||
return types
|
||||
|
||||
def get_tools(self) -> list[BuiltinTool]:
|
||||
"""
|
||||
@@ -123,7 +166,11 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
|
||||
:return: whether the provider needs credentials
|
||||
"""
|
||||
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
|
||||
return (
|
||||
self.entity.credentials_schema is not None
|
||||
and len(self.entity.credentials_schema) != 0
|
||||
or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0)
|
||||
)
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
|
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool import ToolParameter
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.entities.tool_entities import CredentialType, ToolProviderType
|
||||
|
||||
|
||||
class ToolApiEntity(BaseModel):
|
||||
@@ -87,3 +87,22 @@ class ToolProviderApiEntity(BaseModel):
|
||||
def optional_field(self, key: str, value: Any) -> dict:
|
||||
"""Return dict with key-value if value is truthy, empty dict otherwise."""
|
||||
return {key: value} if value else {}
|
||||
|
||||
|
||||
class ToolProviderCredentialApiEntity(BaseModel):
|
||||
id: str = Field(description="The unique id of the credential")
|
||||
name: str = Field(description="The name of the credential")
|
||||
provider: str = Field(description="The provider of the credential")
|
||||
credential_type: CredentialType = Field(description="The type of the credential")
|
||||
is_default: bool = Field(
|
||||
default=False, description="Whether the credential is the default credential for the provider in the workspace"
|
||||
)
|
||||
credentials: dict = Field(description="The credentials of the provider")
|
||||
|
||||
|
||||
class ToolProviderCredentialInfoApiEntity(BaseModel):
|
||||
supported_credential_types: list[str] = Field(description="The supported credential types of the provider")
|
||||
is_oauth_custom_client_enabled: bool = Field(
|
||||
default=False, description="Whether the OAuth custom client is enabled for the provider"
|
||||
)
|
||||
credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider")
|
||||
|
@@ -370,10 +370,18 @@ class ToolEntity(BaseModel):
|
||||
return v or []
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list, description="The schema of the OAuth credentials"
|
||||
)
|
||||
|
||||
|
||||
class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
plugin_id: Optional[str] = None
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||
oauth_schema: Optional[OAuthSchema] = None
|
||||
|
||||
|
||||
class ToolProviderEntityWithPlugin(ToolProviderEntity):
|
||||
@@ -453,6 +461,7 @@ class ToolSelector(BaseModel):
|
||||
options: Optional[list[PluginParameterOption]] = None
|
||||
|
||||
provider_id: str = Field(..., description="The id of the provider")
|
||||
credential_id: Optional[str] = Field(default=None, description="The id of the credential")
|
||||
tool_name: str = Field(..., description="The name of the tool")
|
||||
tool_description: str = Field(..., description="The description of the tool")
|
||||
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
|
||||
@@ -460,3 +469,36 @@ class ToolSelector(BaseModel):
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
return self.model_dump()
|
||||
|
||||
|
||||
class CredentialType(enum.StrEnum):
|
||||
API_KEY = "api-key"
|
||||
OAUTH2 = "oauth2"
|
||||
|
||||
def get_name(self):
|
||||
if self == CredentialType.API_KEY:
|
||||
return "API KEY"
|
||||
elif self == CredentialType.OAUTH2:
|
||||
return "AUTH"
|
||||
else:
|
||||
return self.value.replace("-", " ").upper()
|
||||
|
||||
def is_editable(self):
|
||||
return self == CredentialType.API_KEY
|
||||
|
||||
def is_validate_allowed(self):
|
||||
return self == CredentialType.API_KEY
|
||||
|
||||
@classmethod
|
||||
def values(cls):
|
||||
return [item.value for item in cls]
|
||||
|
||||
@classmethod
|
||||
def of(cls, credential_type: str) -> "CredentialType":
|
||||
type_name = credential_type.lower()
|
||||
if type_name == "api-key":
|
||||
return cls.API_KEY
|
||||
elif type_name == "oauth2":
|
||||
return cls.OAUTH2
|
||||
else:
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
@@ -44,6 +44,7 @@ class PluginTool(Tool):
|
||||
tool_provider=self.entity.identity.provider,
|
||||
tool_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
credential_type=self.runtime.credential_type,
|
||||
tool_parameters=tool_parameters,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
|
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
||||
from yarl import URL
|
||||
|
||||
import contexts
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
@@ -17,6 +18,7 @@ from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
from core.tools.utils.uuid_utils import is_valid_uuid
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||
@@ -24,7 +26,6 @@ from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
|
||||
|
||||
from configs import dify_config
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@@ -41,16 +42,17 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
CredentialType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import (
|
||||
ProviderConfigEncrypter,
|
||||
ToolParameterConfigurationManager,
|
||||
)
|
||||
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
@@ -68,8 +70,11 @@ class ToolManager:
|
||||
@classmethod
|
||||
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
|
||||
"""
|
||||
|
||||
get the hardcoded provider
|
||||
|
||||
"""
|
||||
|
||||
if len(cls._hardcoded_providers) == 0:
|
||||
# init the builtin providers
|
||||
cls.load_hardcoded_providers_cache()
|
||||
@@ -113,7 +118,12 @@ class ToolManager:
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(Lock())
|
||||
|
||||
plugin_tool_providers = contexts.plugin_tool_providers.get()
|
||||
if provider in plugin_tool_providers:
|
||||
return plugin_tool_providers[provider]
|
||||
|
||||
with contexts.plugin_tool_providers_lock.get():
|
||||
# double check
|
||||
plugin_tool_providers = contexts.plugin_tool_providers.get()
|
||||
if provider in plugin_tool_providers:
|
||||
return plugin_tool_providers[provider]
|
||||
@@ -131,25 +141,7 @@ class ToolManager:
|
||||
)
|
||||
|
||||
plugin_tool_providers[provider] = controller
|
||||
|
||||
return controller
|
||||
|
||||
@classmethod
|
||||
def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
|
||||
"""
|
||||
get the builtin tool
|
||||
|
||||
:param provider: the name of the provider
|
||||
:param tool_name: the name of the tool
|
||||
:param tenant_id: the id of the tenant
|
||||
:return: the provider, the tool
|
||||
"""
|
||||
provider_controller = cls.get_builtin_provider(provider, tenant_id)
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
if tool is None:
|
||||
raise ToolNotFoundError(f"tool {tool_name} not found")
|
||||
|
||||
return tool
|
||||
return controller
|
||||
|
||||
@classmethod
|
||||
def get_tool_runtime(
|
||||
@@ -160,6 +152,7 @@ class ToolManager:
|
||||
tenant_id: str,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
|
||||
credential_id: Optional[str] = None,
|
||||
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
@@ -170,6 +163,7 @@ class ToolManager:
|
||||
:param tenant_id: the tenant id
|
||||
:param invoke_from: invoke from
|
||||
:param tool_invoke_from: the tool invoke from
|
||||
:param credential_id: the credential id
|
||||
|
||||
:return: the tool
|
||||
"""
|
||||
@@ -193,49 +187,70 @@ class ToolManager:
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
builtin_provider = None
|
||||
if isinstance(provider_controller, PluginToolProviderController):
|
||||
provider_id_entity = ToolProviderID(provider_id)
|
||||
# get credentials
|
||||
builtin_provider: BuiltinToolProvider | None = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
# get specific credentials
|
||||
if is_valid_uuid(credential_id):
|
||||
try:
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.id == credential_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
except Exception as e:
|
||||
builtin_provider = None
|
||||
logger.info(f"Error getting builtin provider {credential_id}:{e}", exc_info=True)
|
||||
# if the provider has been deleted, raise an error
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
|
||||
|
||||
# fallback to the default provider
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
||||
# use the default provider
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
|
||||
else:
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
||||
|
||||
# decrypt the credentials
|
||||
credentials = builtin_provider.credentials
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
|
||||
],
|
||||
cache=ToolProviderCredentialsCache(
|
||||
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
|
||||
),
|
||||
)
|
||||
|
||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
||||
|
||||
return cast(
|
||||
BuiltinTool,
|
||||
builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials=decrypted_credentials,
|
||||
credentials=encrypter.decrypt(builtin_provider.credentials),
|
||||
credential_type=CredentialType.of(builtin_provider.credential_type),
|
||||
runtime_parameters={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
@@ -245,22 +260,16 @@ class ToolManager:
|
||||
|
||||
elif provider_type == ToolProviderType.API:
|
||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||
|
||||
# decrypt the credentials
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
|
||||
provider_type=api_provider.provider_type.value,
|
||||
provider_identity=api_provider.entity.identity.name,
|
||||
controller=api_provider,
|
||||
)
|
||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
||||
|
||||
return cast(
|
||||
ApiTool,
|
||||
api_provider.get_tool(tool_name).fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials=decrypted_credentials,
|
||||
credentials=encrypter.decrypt(credentials),
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
@@ -320,6 +329,7 @@ class ToolManager:
|
||||
tenant_id=tenant_id,
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||
credential_id=agent_tool.credential_id,
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
@@ -362,6 +372,7 @@ class ToolManager:
|
||||
tenant_id=tenant_id,
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
credential_id=workflow_tool.credential_id,
|
||||
)
|
||||
|
||||
parameters = tool_runtime.get_merged_runtime_parameters()
|
||||
@@ -391,6 +402,7 @@ class ToolManager:
|
||||
provider: str,
|
||||
tool_name: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
credential_id: Optional[str] = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get tool runtime from plugin
|
||||
@@ -402,6 +414,7 @@ class ToolManager:
|
||||
tenant_id=tenant_id,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
tool_invoke_from=ToolInvokeFrom.PLUGIN,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
@@ -551,6 +564,22 @@ class ToolManager:
|
||||
|
||||
return cls._builtin_tools_labels[tool_name]
|
||||
|
||||
@classmethod
|
||||
def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]:
|
||||
"""
|
||||
list all the builtin providers
|
||||
"""
|
||||
# according to multi credentials, select the one with is_default=True first, then created_at oldest
|
||||
# for compatibility with old version
|
||||
sql = """
|
||||
SELECT DISTINCT ON (tenant_id, provider) id
|
||||
FROM tool_builtin_providers
|
||||
WHERE tenant_id = :tenant_id
|
||||
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
|
||||
"""
|
||||
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
|
||||
|
||||
@classmethod
|
||||
def list_providers_from_api(
|
||||
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
|
||||
@@ -565,21 +594,13 @@ class ToolManager:
|
||||
|
||||
with db.session.no_autoflush:
|
||||
if "builtin" in filters:
|
||||
# get builtin providers
|
||||
builtin_providers = cls.list_builtin_providers(tenant_id)
|
||||
|
||||
# get db builtin providers
|
||||
db_builtin_providers: list[BuiltinToolProvider] = (
|
||||
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
# rewrite db_builtin_providers
|
||||
for db_provider in db_builtin_providers:
|
||||
tool_provider_id = str(ToolProviderID(db_provider.provider))
|
||||
db_provider.provider = tool_provider_id
|
||||
|
||||
def find_db_builtin_provider(provider):
|
||||
return next((x for x in db_builtin_providers if x.provider == provider), None)
|
||||
# key: provider name, value: provider
|
||||
db_builtin_providers = {
|
||||
str(ToolProviderID(provider.provider)): provider
|
||||
for provider in cls.list_default_builtin_providers(tenant_id)
|
||||
}
|
||||
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
@@ -591,10 +612,9 @@ class ToolManager:
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider,
|
||||
db_provider=find_db_builtin_provider(provider.entity.identity.name),
|
||||
db_provider=db_builtin_providers.get(provider.entity.identity.name),
|
||||
decrypt_credentials=False,
|
||||
)
|
||||
|
||||
@@ -604,7 +624,6 @@ class ToolManager:
|
||||
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
# get db api providers
|
||||
|
||||
if "api" in filters:
|
||||
db_api_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
@@ -764,15 +783,12 @@ class ToolManager:
|
||||
auth_type,
|
||||
)
|
||||
# init tool configuration
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
|
||||
provider_type=controller.provider_type.value,
|
||||
provider_identity=controller.entity.identity.name,
|
||||
controller=controller,
|
||||
)
|
||||
|
||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||
masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
|
||||
|
||||
try:
|
||||
icon = json.loads(provider_obj.icon)
|
||||
|
@@ -1,12 +1,8 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
@@ -14,110 +10,6 @@ from core.tools.entities.tool_entities import (
|
||||
)
|
||||
|
||||
|
||||
class ProviderConfigEncrypter(BaseModel):
|
||||
tenant_id: str
|
||||
config: list[BasicProviderConfig]
|
||||
provider_type: str
|
||||
provider_identity: str
|
||||
|
||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
deep copy data
|
||||
"""
|
||||
return deepcopy(data)
|
||||
|
||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
encrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with encrypted values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||
data[field_name] = encrypted
|
||||
|
||||
return data
|
||||
|
||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool credentials
|
||||
|
||||
return a deep copy of credentials with masked values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
if len(data[field_name]) > 6:
|
||||
data[field_name] = (
|
||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||
)
|
||||
else:
|
||||
data[field_name] = "*" * len(data[field_name])
|
||||
|
||||
return data
|
||||
|
||||
def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, str]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
if use_cache:
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
try:
|
||||
# if the value is None or empty string, skip decrypt
|
||||
if not data[field_name]:
|
||||
continue
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if use_cache:
|
||||
cache.set(data)
|
||||
return data
|
||||
|
||||
def delete_tool_credentials_cache(self):
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cache.delete()
|
||||
|
||||
|
||||
class ToolParameterConfigurationManager:
|
||||
"""
|
||||
Tool parameter configuration manager
|
||||
|
142
api/core/tools/utils/encryption.py
Normal file
142
api/core/tools/utils/encryption.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional, Protocol
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import SingletonProviderCredentialsCache
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
|
||||
|
||||
class ProviderConfigCache(Protocol):
|
||||
"""
|
||||
Interface for provider configuration cache operations
|
||||
"""
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""Get cached provider configuration"""
|
||||
...
|
||||
|
||||
def set(self, config: dict[str, Any]) -> None:
|
||||
"""Cache provider configuration"""
|
||||
...
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete cached provider configuration"""
|
||||
...
|
||||
|
||||
|
||||
class ProviderConfigEncrypter:
|
||||
tenant_id: str
|
||||
config: list[BasicProviderConfig]
|
||||
provider_config_cache: ProviderConfigCache
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
config: list[BasicProviderConfig],
|
||||
provider_config_cache: ProviderConfigCache,
|
||||
):
|
||||
self.tenant_id = tenant_id
|
||||
self.config = config
|
||||
self.provider_config_cache = provider_config_cache
|
||||
|
||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
deep copy data
|
||||
"""
|
||||
return deepcopy(data)
|
||||
|
||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
encrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with encrypted values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||
data[field_name] = encrypted
|
||||
|
||||
return data
|
||||
|
||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool credentials
|
||||
|
||||
return a deep copy of credentials with masked values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
if len(data[field_name]) > 6:
|
||||
data[field_name] = (
|
||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||
)
|
||||
else:
|
||||
data[field_name] = "*" * len(data[field_name])
|
||||
|
||||
return data
|
||||
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cached_credentials = self.provider_config_cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
try:
|
||||
# if the value is None or empty string, skip decrypt
|
||||
if not data[field_name]:
|
||||
continue
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.provider_config_cache.set(data)
|
||||
return data
|
||||
|
||||
|
||||
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
|
||||
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
|
||||
|
||||
|
||||
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
|
||||
cache = SingletonProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_type=controller.provider_type.value,
|
||||
provider_identity=controller.entity.identity.name,
|
||||
)
|
||||
encrypt = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
|
||||
provider_config_cache=cache,
|
||||
)
|
||||
return encrypt, cache
|
187
api/core/tools/utils/system_oauth_encryption.py
Normal file
187
api/core/tools/utils/system_oauth_encryption.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Random import get_random_bytes
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthEncryptionError(Exception):
|
||||
"""OAuth encryption/decryption specific error"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SystemOAuthEncrypter:
|
||||
"""
|
||||
A simple OAuth parameters encrypter using AES-CBC encryption.
|
||||
|
||||
This class provides methods to encrypt and decrypt OAuth parameters
|
||||
using AES-CBC mode with a key derived from the application's SECRET_KEY.
|
||||
"""
|
||||
|
||||
def __init__(self, secret_key: Optional[str] = None):
|
||||
"""
|
||||
Initialize the OAuth encrypter.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
|
||||
Raises:
|
||||
ValueError: If SECRET_KEY is not configured or empty
|
||||
"""
|
||||
secret_key = secret_key or dify_config.SECRET_KEY or ""
|
||||
|
||||
# Generate a fixed 256-bit key using SHA-256
|
||||
self.key = hashlib.sha256(secret_key.encode()).digest()
|
||||
|
||||
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If encryption fails
|
||||
ValueError: If oauth_params is invalid
|
||||
"""
|
||||
|
||||
try:
|
||||
# Generate random IV (16 bytes)
|
||||
iv = get_random_bytes(16)
|
||||
|
||||
# Create AES cipher (CBC mode)
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
|
||||
# Encrypt data
|
||||
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
|
||||
encrypted_data = cipher.encrypt(padded_data)
|
||||
|
||||
# Combine IV and encrypted data
|
||||
combined = iv + encrypted_data
|
||||
|
||||
# Return base64 encoded string
|
||||
return base64.b64encode(combined).decode()
|
||||
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
|
||||
|
||||
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If decryption fails
|
||||
ValueError: If encrypted_data is invalid
|
||||
"""
|
||||
if not isinstance(encrypted_data, str):
|
||||
raise ValueError("encrypted_data must be a string")
|
||||
|
||||
if not encrypted_data:
|
||||
raise ValueError("encrypted_data cannot be empty")
|
||||
|
||||
try:
|
||||
# Base64 decode
|
||||
combined = base64.b64decode(encrypted_data)
|
||||
|
||||
# Check minimum length (IV + at least one AES block)
|
||||
if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data
|
||||
raise ValueError("Invalid encrypted data format")
|
||||
|
||||
# Separate IV and encrypted data
|
||||
iv = combined[:16]
|
||||
encrypted_data_bytes = combined[16:]
|
||||
|
||||
# Create AES cipher
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
|
||||
# Decrypt data
|
||||
decrypted_data = cipher.decrypt(encrypted_data_bytes)
|
||||
unpadded_data = unpad(decrypted_data, AES.block_size)
|
||||
|
||||
# Parse JSON
|
||||
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
|
||||
|
||||
if not isinstance(oauth_params, dict):
|
||||
raise ValueError("Decrypted data is not a valid dictionary")
|
||||
|
||||
return oauth_params
|
||||
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
|
||||
|
||||
|
||||
# Factory function for creating encrypter instances
|
||||
def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter:
|
||||
"""
|
||||
Create an OAuth encrypter instance.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
"""
|
||||
return SystemOAuthEncrypter(secret_key=secret_key)
|
||||
|
||||
|
||||
# Global encrypter instance (for backward compatibility)
|
||||
_oauth_encrypter: Optional[SystemOAuthEncrypter] = None
|
||||
|
||||
|
||||
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
|
||||
"""
|
||||
Get the global OAuth encrypter instance.
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
"""
|
||||
global _oauth_encrypter
|
||||
if _oauth_encrypter is None:
|
||||
_oauth_encrypter = SystemOAuthEncrypter()
|
||||
return _oauth_encrypter
|
||||
|
||||
|
||||
# Convenience functions for backward compatibility
|
||||
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
"""
|
||||
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
|
||||
|
||||
|
||||
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
"""
|
||||
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
|
@@ -1,7 +1,9 @@
|
||||
import uuid
|
||||
|
||||
|
||||
def is_valid_uuid(uuid_str: str) -> bool:
|
||||
def is_valid_uuid(uuid_str: str | None) -> bool:
|
||||
if uuid_str is None or len(uuid_str) == 0:
|
||||
return False
|
||||
try:
|
||||
uuid.UUID(uuid_str)
|
||||
return True
|
||||
|
Reference in New Issue
Block a user