feat: oauth refresh token (#22744)

Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
Maries
2025-07-23 13:12:39 +08:00
committed by GitHub
parent 6d3e198c3c
commit ad67094e54
7 changed files with 125 additions and 4 deletions

View File

@@ -739,7 +739,7 @@ class ToolOAuthCallback(Resource):
raise Forbidden("no oauth available client config found for this tool provider") raise Forbidden("no oauth available client config found for this tool provider")
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
credentials = oauth_handler.get_credentials( credentials_response = oauth_handler.get_credentials(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user_id, user_id=user_id,
plugin_id=plugin_id, plugin_id=plugin_id,
@@ -747,7 +747,10 @@ class ToolOAuthCallback(Resource):
redirect_uri=redirect_uri, redirect_uri=redirect_uri,
system_credentials=oauth_client_params, system_credentials=oauth_client_params,
request=request, request=request,
).credentials )
credentials = credentials_response.credentials
expires_at = credentials_response.expires_at
if not credentials: if not credentials:
raise Exception("the plugin credentials failed") raise Exception("the plugin credentials failed")
@@ -758,6 +761,7 @@ class ToolOAuthCallback(Resource):
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
credentials=dict(credentials), credentials=dict(credentials),
expires_at=expires_at,
api_type=CredentialType.OAUTH2, api_type=CredentialType.OAUTH2,
) )
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

View File

@@ -182,6 +182,10 @@ class PluginOAuthAuthorizationUrlResponse(BaseModel):
class PluginOAuthCredentialsResponse(BaseModel): class PluginOAuthCredentialsResponse(BaseModel):
metadata: Mapping[str, Any] = Field(
default_factory=dict, description="The metadata of the OAuth, like avatar url, name, etc."
)
expires_at: int = Field(default=-1, description="The expires at time of the credentials. UTC timestamp.")
credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.") credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")

View File

@@ -84,6 +84,41 @@ class OAuthHandler(BasePluginClient):
except Exception as e: except Exception as e:
raise ValueError(f"Error getting credentials: {e}") raise ValueError(f"Error getting credentials: {e}")
def refresh_credentials(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
redirect_uri: str,
system_credentials: Mapping[str, Any],
credentials: Mapping[str, Any],
) -> PluginOAuthCredentialsResponse:
try:
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/refresh_credentials",
PluginOAuthCredentialsResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"redirect_uri": redirect_uri,
"system_credentials": system_credentials,
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for refresh credentials request.")
except Exception as e:
raise ValueError(f"Error refreshing credentials: {e}")
def _convert_request_to_raw_data(self, request: Request) -> bytes: def _convert_request_to_raw_data(self, request: Request) -> bytes:
""" """
Convert a Request object to raw HTTP data. Convert a Request object to raw HTTP data.

View File

@@ -1,16 +1,19 @@
import json import json
import logging import logging
import mimetypes import mimetypes
from collections.abc import Generator import time
from collections.abc import Generator, Mapping
from os import listdir, path from os import listdir, path
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from pydantic import TypeAdapter
from yarl import URL from yarl import URL
import contexts import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.plugin.impl.tool import PluginToolManager from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
@@ -244,12 +247,47 @@ class ToolManager:
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
), ),
) )
# decrypt the credentials
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
# check if the credentials is expired
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
# TODO: circular import
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
# refresh the credentials
tool_provider = ToolProviderID(provider_id)
provider_name = tool_provider.provider_name
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
oauth_handler = OAuthHandler()
# refresh the credentials
refreshed_credentials = oauth_handler.refresh_credentials(
tenant_id=tenant_id,
user_id=builtin_provider.user_id,
plugin_id=tool_provider.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
# update the credentials
builtin_provider.encrypted_credentials = (
TypeAdapter(dict[str, Any])
.dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
.decode("utf-8")
)
builtin_provider.expires_at = refreshed_credentials.expires_at
db.session.commit()
decrypted_credentials = refreshed_credentials.credentials
return cast( return cast(
BuiltinTool, BuiltinTool,
builtin_tool.fork_tool_runtime( builtin_tool.fork_tool_runtime(
runtime=ToolRuntime( runtime=ToolRuntime(
tenant_id=tenant_id, tenant_id=tenant_id,
credentials=encrypter.decrypt(builtin_provider.credentials), credentials=dict(decrypted_credentials),
credential_type=CredentialType.of(builtin_provider.credential_type), credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={}, runtime_parameters={},
invoke_from=invoke_from, invoke_from=invoke_from,

View File

@@ -0,0 +1,34 @@
"""oauth_refresh_token
Revision ID: 375fe79ead14
Revises: 1a83934ad6d1
Create Date: 2025-07-22 00:19:45.599636
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '375fe79ead14'
down_revision = '1a83934ad6d1'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('expires_at', sa.BigInteger(), server_default=sa.text('-1'), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.drop_column('expires_at')
# ### end Alembic commands ###

View File

@@ -93,6 +93,7 @@ class BuiltinToolProvider(Base):
credential_type: Mapped[str] = mapped_column( credential_type: Mapped[str] = mapped_column(
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
) )
expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1"))
@property @property
def credentials(self) -> dict: def credentials(self) -> dict:

View File

@@ -38,6 +38,7 @@ logger = logging.getLogger(__name__)
class BuiltinToolManageService: class BuiltinToolManageService:
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100 __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
__DEFAULT_EXPIRES_AT__ = 2147483647
@staticmethod @staticmethod
def delete_custom_oauth_client_params(tenant_id: str, provider: str): def delete_custom_oauth_client_params(tenant_id: str, provider: str):
@@ -212,6 +213,7 @@ class BuiltinToolManageService:
tenant_id: str, tenant_id: str,
provider: str, provider: str,
credentials: dict, credentials: dict,
expires_at: int = -1,
name: str | None = None, name: str | None = None,
): ):
""" """
@@ -269,6 +271,9 @@ class BuiltinToolManageService:
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credential_type=api_type.value, credential_type=api_type.value,
name=name, name=name,
expires_at=expires_at
if expires_at is not None
else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__,
) )
session.add(db_provider) session.add(db_provider)