From ad67094e54838fe7bbd23ef2b74ee6f1fbcd75c2 Mon Sep 17 00:00:00 2001 From: Maries Date: Wed, 23 Jul 2025 13:12:39 +0800 Subject: [PATCH] feat: oauth refresh token (#22744) Co-authored-by: Yeuoly --- .../console/workspace/tool_providers.py | 8 +++- api/core/plugin/entities/plugin_daemon.py | 4 ++ api/core/plugin/impl/oauth.py | 35 ++++++++++++++++ api/core/tools/tool_manager.py | 42 ++++++++++++++++++- ...2_0019-375fe79ead14_oauth_refresh_token.py | 34 +++++++++++++++ api/models/tools.py | 1 + .../tools/builtin_tools_manage_service.py | 5 +++ 7 files changed, 125 insertions(+), 4 deletions(-) create mode 100644 api/migrations/versions/2025_07_22_0019-375fe79ead14_oauth_refresh_token.py diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index c70bf84d2..c4d1ef70d 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -739,7 +739,7 @@ class ToolOAuthCallback(Resource): raise Forbidden("no oauth available client config found for this tool provider") redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" - credentials = oauth_handler.get_credentials( + credentials_response = oauth_handler.get_credentials( tenant_id=tenant_id, user_id=user_id, plugin_id=plugin_id, @@ -747,7 +747,10 @@ class ToolOAuthCallback(Resource): redirect_uri=redirect_uri, system_credentials=oauth_client_params, request=request, - ).credentials + ) + + credentials = credentials_response.credentials + expires_at = credentials_response.expires_at if not credentials: raise Exception("the plugin credentials failed") @@ -758,6 +761,7 @@ class ToolOAuthCallback(Resource): tenant_id=tenant_id, provider=provider, credentials=dict(credentials), + expires_at=expires_at, api_type=CredentialType.OAUTH2, ) return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 00253b8a1..16ab66109 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -182,6 +182,10 @@ class PluginOAuthAuthorizationUrlResponse(BaseModel): class PluginOAuthCredentialsResponse(BaseModel): + metadata: Mapping[str, Any] = Field( + default_factory=dict, description="The metadata of the OAuth, like avatar url, name, etc." + ) + expires_at: int = Field(default=-1, description="The expires at time of the credentials. UTC timestamp.") credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.") diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py index d73e5d9f9..7f022992f 100644 --- a/api/core/plugin/impl/oauth.py +++ b/api/core/plugin/impl/oauth.py @@ -84,6 +84,41 @@ class OAuthHandler(BasePluginClient): except Exception as e: raise ValueError(f"Error getting credentials: {e}") + def refresh_credentials( + self, + tenant_id: str, + user_id: str, + plugin_id: str, + provider: str, + redirect_uri: str, + system_credentials: Mapping[str, Any], + credentials: Mapping[str, Any], + ) -> PluginOAuthCredentialsResponse: + try: + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/oauth/refresh_credentials", + PluginOAuthCredentialsResponse, + data={ + "user_id": user_id, + "data": { + "provider": provider, + "redirect_uri": redirect_uri, + "system_credentials": system_credentials, + "credentials": credentials, + }, + }, + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", + }, + ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for refresh credentials request.") + except Exception as e: + raise ValueError(f"Error refreshing credentials: {e}") + def _convert_request_to_raw_data(self, request: Request) -> bytes: """ Convert a Request object to raw HTTP data. diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 7822bc389..abbdf8de3 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -1,16 +1,19 @@ import json import logging import mimetypes -from collections.abc import Generator +import time +from collections.abc import Generator, Mapping from os import listdir, path from threading import Lock from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast +from pydantic import TypeAdapter from yarl import URL import contexts from core.helper.provider_cache import ToolProviderCredentialsCache from core.plugin.entities.plugin import ToolProviderID +from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.tool import PluginToolManager from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_runtime import ToolRuntime @@ -244,12 +247,47 @@ class ToolManager: tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id ), ) + + # decrypt the credentials + decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials) + + # check if the credentials is expired + if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()): + # TODO: circular import + from services.tools.builtin_tools_manage_service import BuiltinToolManageService + + # refresh the credentials + tool_provider = ToolProviderID(provider_id) + provider_name = tool_provider.provider_name + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback" + system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id) + oauth_handler = OAuthHandler() + # refresh the credentials + refreshed_credentials = oauth_handler.refresh_credentials( + tenant_id=tenant_id, + user_id=builtin_provider.user_id, + plugin_id=tool_provider.plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, + system_credentials=system_credentials or {}, + credentials=decrypted_credentials, + ) + # update the credentials + builtin_provider.encrypted_credentials = ( + TypeAdapter(dict[str, Any]) + .dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials))) + .decode("utf-8") + ) + builtin_provider.expires_at = refreshed_credentials.expires_at + db.session.commit() + decrypted_credentials = refreshed_credentials.credentials + return cast( BuiltinTool, builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, - credentials=encrypter.decrypt(builtin_provider.credentials), + credentials=dict(decrypted_credentials), credential_type=CredentialType.of(builtin_provider.credential_type), runtime_parameters={}, invoke_from=invoke_from, diff --git a/api/migrations/versions/2025_07_22_0019-375fe79ead14_oauth_refresh_token.py b/api/migrations/versions/2025_07_22_0019-375fe79ead14_oauth_refresh_token.py new file mode 100644 index 000000000..76d0cb294 --- /dev/null +++ b/api/migrations/versions/2025_07_22_0019-375fe79ead14_oauth_refresh_token.py @@ -0,0 +1,34 @@ +"""oauth_refresh_token + +Revision ID: 375fe79ead14 +Revises: 1a83934ad6d1 +Create Date: 2025-07-22 00:19:45.599636 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '375fe79ead14' +down_revision = '1a83934ad6d1' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('expires_at', sa.BigInteger(), server_default=sa.text('-1'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.drop_column('expires_at') + + # ### end Alembic commands ### diff --git a/api/models/tools.py b/api/models/tools.py index a0b7e5417..8c91e91f0 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -93,6 +93,7 @@ class BuiltinToolProvider(Base): credential_type: Mapped[str] = mapped_column( db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") ) + expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1")) @property def credentials(self) -> dict: diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 430575b53..b8e3ce265 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -38,6 +38,7 @@ logger = logging.getLogger(__name__) class BuiltinToolManageService: __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100 + __DEFAULT_EXPIRES_AT__ = 2147483647 @staticmethod def delete_custom_oauth_client_params(tenant_id: str, provider: str): @@ -212,6 +213,7 @@ class BuiltinToolManageService: tenant_id: str, provider: str, credentials: dict, + expires_at: int = -1, name: str | None = None, ): """ @@ -269,6 +271,9 @@ class BuiltinToolManageService: encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), credential_type=api_type.value, name=name, + expires_at=expires_at + if expires_at is not None + else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__, ) session.add(db_provider)