feat: oauth refresh token (#22744)
Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
@@ -739,7 +739,7 @@ class ToolOAuthCallback(Resource):
|
|||||||
raise Forbidden("no oauth available client config found for this tool provider")
|
raise Forbidden("no oauth available client config found for this tool provider")
|
||||||
|
|
||||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
|
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
|
||||||
credentials = oauth_handler.get_credentials(
|
credentials_response = oauth_handler.get_credentials(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
@@ -747,7 +747,10 @@ class ToolOAuthCallback(Resource):
|
|||||||
redirect_uri=redirect_uri,
|
redirect_uri=redirect_uri,
|
||||||
system_credentials=oauth_client_params,
|
system_credentials=oauth_client_params,
|
||||||
request=request,
|
request=request,
|
||||||
).credentials
|
)
|
||||||
|
|
||||||
|
credentials = credentials_response.credentials
|
||||||
|
expires_at = credentials_response.expires_at
|
||||||
|
|
||||||
if not credentials:
|
if not credentials:
|
||||||
raise Exception("the plugin credentials failed")
|
raise Exception("the plugin credentials failed")
|
||||||
@@ -758,6 +761,7 @@ class ToolOAuthCallback(Resource):
|
|||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
credentials=dict(credentials),
|
credentials=dict(credentials),
|
||||||
|
expires_at=expires_at,
|
||||||
api_type=CredentialType.OAUTH2,
|
api_type=CredentialType.OAUTH2,
|
||||||
)
|
)
|
||||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||||
|
@@ -182,6 +182,10 @@ class PluginOAuthAuthorizationUrlResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class PluginOAuthCredentialsResponse(BaseModel):
|
class PluginOAuthCredentialsResponse(BaseModel):
|
||||||
|
metadata: Mapping[str, Any] = Field(
|
||||||
|
default_factory=dict, description="The metadata of the OAuth, like avatar url, name, etc."
|
||||||
|
)
|
||||||
|
expires_at: int = Field(default=-1, description="The expires at time of the credentials. UTC timestamp.")
|
||||||
credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")
|
credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")
|
||||||
|
|
||||||
|
|
||||||
|
@@ -84,6 +84,41 @@ class OAuthHandler(BasePluginClient):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Error getting credentials: {e}")
|
raise ValueError(f"Error getting credentials: {e}")
|
||||||
|
|
||||||
|
def refresh_credentials(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
plugin_id: str,
|
||||||
|
provider: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
system_credentials: Mapping[str, Any],
|
||||||
|
credentials: Mapping[str, Any],
|
||||||
|
) -> PluginOAuthCredentialsResponse:
|
||||||
|
try:
|
||||||
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
|
"POST",
|
||||||
|
f"plugin/{tenant_id}/dispatch/oauth/refresh_credentials",
|
||||||
|
PluginOAuthCredentialsResponse,
|
||||||
|
data={
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": provider,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"system_credentials": system_credentials,
|
||||||
|
"credentials": credentials,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for resp in response:
|
||||||
|
return resp
|
||||||
|
raise ValueError("No response received from plugin daemon for refresh credentials request.")
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error refreshing credentials: {e}")
|
||||||
|
|
||||||
def _convert_request_to_raw_data(self, request: Request) -> bytes:
|
def _convert_request_to_raw_data(self, request: Request) -> bytes:
|
||||||
"""
|
"""
|
||||||
Convert a Request object to raw HTTP data.
|
Convert a Request object to raw HTTP data.
|
||||||
|
@@ -1,16 +1,19 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import mimetypes
|
import mimetypes
|
||||||
from collections.abc import Generator
|
import time
|
||||||
|
from collections.abc import Generator, Mapping
|
||||||
from os import listdir, path
|
from os import listdir, path
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
||||||
|
|
||||||
|
from pydantic import TypeAdapter
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||||
from core.plugin.entities.plugin import ToolProviderID
|
from core.plugin.entities.plugin import ToolProviderID
|
||||||
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
from core.plugin.impl.tool import PluginToolManager
|
from core.plugin.impl.tool import PluginToolManager
|
||||||
from core.tools.__base.tool_provider import ToolProviderController
|
from core.tools.__base.tool_provider import ToolProviderController
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
@@ -244,12 +247,47 @@ class ToolManager:
|
|||||||
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
|
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# decrypt the credentials
|
||||||
|
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
|
||||||
|
|
||||||
|
# check if the credentials is expired
|
||||||
|
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
|
||||||
|
# TODO: circular import
|
||||||
|
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||||
|
|
||||||
|
# refresh the credentials
|
||||||
|
tool_provider = ToolProviderID(provider_id)
|
||||||
|
provider_name = tool_provider.provider_name
|
||||||
|
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
|
||||||
|
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
|
||||||
|
oauth_handler = OAuthHandler()
|
||||||
|
# refresh the credentials
|
||||||
|
refreshed_credentials = oauth_handler.refresh_credentials(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=builtin_provider.user_id,
|
||||||
|
plugin_id=tool_provider.plugin_id,
|
||||||
|
provider=provider_name,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
system_credentials=system_credentials or {},
|
||||||
|
credentials=decrypted_credentials,
|
||||||
|
)
|
||||||
|
# update the credentials
|
||||||
|
builtin_provider.encrypted_credentials = (
|
||||||
|
TypeAdapter(dict[str, Any])
|
||||||
|
.dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
|
||||||
|
.decode("utf-8")
|
||||||
|
)
|
||||||
|
builtin_provider.expires_at = refreshed_credentials.expires_at
|
||||||
|
db.session.commit()
|
||||||
|
decrypted_credentials = refreshed_credentials.credentials
|
||||||
|
|
||||||
return cast(
|
return cast(
|
||||||
BuiltinTool,
|
BuiltinTool,
|
||||||
builtin_tool.fork_tool_runtime(
|
builtin_tool.fork_tool_runtime(
|
||||||
runtime=ToolRuntime(
|
runtime=ToolRuntime(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
credentials=encrypter.decrypt(builtin_provider.credentials),
|
credentials=dict(decrypted_credentials),
|
||||||
credential_type=CredentialType.of(builtin_provider.credential_type),
|
credential_type=CredentialType.of(builtin_provider.credential_type),
|
||||||
runtime_parameters={},
|
runtime_parameters={},
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
|
@@ -0,0 +1,34 @@
|
|||||||
|
"""oauth_refresh_token
|
||||||
|
|
||||||
|
Revision ID: 375fe79ead14
|
||||||
|
Revises: 1a83934ad6d1
|
||||||
|
Create Date: 2025-07-22 00:19:45.599636
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '375fe79ead14'
|
||||||
|
down_revision = '1a83934ad6d1'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('expires_at', sa.BigInteger(), server_default=sa.text('-1'), nullable=False))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
|
||||||
|
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('expires_at')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
@@ -93,6 +93,7 @@ class BuiltinToolProvider(Base):
|
|||||||
credential_type: Mapped[str] = mapped_column(
|
credential_type: Mapped[str] = mapped_column(
|
||||||
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
|
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
|
||||||
)
|
)
|
||||||
|
expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1"))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def credentials(self) -> dict:
|
def credentials(self) -> dict:
|
||||||
|
@@ -38,6 +38,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class BuiltinToolManageService:
|
class BuiltinToolManageService:
|
||||||
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
|
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
|
||||||
|
__DEFAULT_EXPIRES_AT__ = 2147483647
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_custom_oauth_client_params(tenant_id: str, provider: str):
|
def delete_custom_oauth_client_params(tenant_id: str, provider: str):
|
||||||
@@ -212,6 +213,7 @@ class BuiltinToolManageService:
|
|||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
credentials: dict,
|
credentials: dict,
|
||||||
|
expires_at: int = -1,
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -269,6 +271,9 @@ class BuiltinToolManageService:
|
|||||||
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
|
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
|
||||||
credential_type=api_type.value,
|
credential_type=api_type.value,
|
||||||
name=name,
|
name=name,
|
||||||
|
expires_at=expires_at
|
||||||
|
if expires_at is not None
|
||||||
|
else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__,
|
||||||
)
|
)
|
||||||
|
|
||||||
session.add(db_provider)
|
session.add(db_provider)
|
||||||
|
Reference in New Issue
Block a user