feat: oauth provider (#24206)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: yessenia <yessenia.contact@gmail.com>
This commit is contained in:

committed by
GitHub

parent
3d5a4df9d0
commit
f32e176d6a
@@ -70,7 +70,7 @@ from .app import (
|
||||
)
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth
|
||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server
|
||||
|
||||
# Import billing controllers
|
||||
from .billing import billing, compliance
|
||||
|
189
api/controllers/console/auth/oauth_server.py
Normal file
189
api/controllers/console/auth/oauth_server.py
Normal file
@@ -0,0 +1,189 @@
|
||||
from functools import wraps
|
||||
from typing import cast
|
||||
|
||||
import flask_login
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
||||
|
||||
from .. import api
|
||||
|
||||
|
||||
def oauth_server_client_id_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_id", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
client_id = parsed_args.get("client_id")
|
||||
if not client_id:
|
||||
raise BadRequest("client_id is required")
|
||||
|
||||
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
|
||||
if not oauth_provider_app:
|
||||
raise NotFound("client_id is invalid")
|
||||
|
||||
kwargs["oauth_provider_app"] = oauth_provider_app
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def oauth_server_access_token_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
oauth_provider_app = kwargs.get("oauth_provider_app")
|
||||
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||
raise BadRequest("Invalid oauth_provider_app")
|
||||
|
||||
if not request.headers.get("Authorization"):
|
||||
raise BadRequest("Authorization is required")
|
||||
|
||||
authorization_header = request.headers.get("Authorization")
|
||||
if not authorization_header:
|
||||
raise BadRequest("Authorization header is required")
|
||||
|
||||
parts = authorization_header.split(" ")
|
||||
if len(parts) != 2:
|
||||
raise BadRequest("Invalid Authorization header format")
|
||||
|
||||
token_type = parts[0]
|
||||
if token_type != "Bearer":
|
||||
raise BadRequest("token_type is invalid")
|
||||
|
||||
access_token = parts[1]
|
||||
if not access_token:
|
||||
raise BadRequest("access_token is required")
|
||||
|
||||
account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token)
|
||||
if not account:
|
||||
raise BadRequest("access_token or client_id is invalid")
|
||||
|
||||
kwargs["account"] = account
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
class OAuthServerAppApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("redirect_uri", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
redirect_uri = parsed_args.get("redirect_uri")
|
||||
|
||||
# check if redirect_uri is valid
|
||||
if redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"app_icon": oauth_provider_app.app_icon,
|
||||
"app_label": oauth_provider_app.app_label,
|
||||
"scope": oauth_provider_app.scope,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthServerUserAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
account = cast(Account, flask_login.current_user)
|
||||
user_account_id = account.id
|
||||
|
||||
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"code": code,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthServerUserTokenApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("grant_type", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=False, location="json")
|
||||
parser.add_argument("client_secret", type=str, required=False, location="json")
|
||||
parser.add_argument("redirect_uri", type=str, required=False, location="json")
|
||||
parser.add_argument("refresh_token", type=str, required=False, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
|
||||
grant_type = OAuthGrantType(parsed_args["grant_type"])
|
||||
|
||||
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
|
||||
if not parsed_args["code"]:
|
||||
raise BadRequest("code is required")
|
||||
|
||||
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
|
||||
raise BadRequest("client_secret is invalid")
|
||||
|
||||
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
|
||||
if not parsed_args["refresh_token"]:
|
||||
raise BadRequest("refresh_token is required")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise BadRequest("invalid grant_type")
|
||||
|
||||
|
||||
class OAuthServerUserAccountApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
@oauth_server_access_token_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp, account: Account):
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"name": account.name,
|
||||
"email": account.email,
|
||||
"avatar": account.avatar,
|
||||
"interface_language": account.interface_language,
|
||||
"timezone": account.timezone,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
api.add_resource(OAuthServerAppApi, "/oauth/provider")
|
||||
api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize")
|
||||
api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token")
|
||||
api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account")
|
@@ -0,0 +1,45 @@
|
||||
"""empty message
|
||||
|
||||
Revision ID: 8d289573e1da
|
||||
Revises: fa8b0fa6f407
|
||||
Create Date: 2025-08-20 17:47:17.015695
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '8d289573e1da'
|
||||
down_revision = '0e154742a5fa'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('oauth_provider_apps',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('app_icon', sa.String(length=255), nullable=False),
|
||||
sa.Column('app_label', sa.JSON(), server_default='{}', nullable=False),
|
||||
sa.Column('client_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('client_secret', sa.String(length=255), nullable=False),
|
||||
sa.Column('redirect_uris', sa.JSON(), server_default='[]', nullable=False),
|
||||
sa.Column('scope', sa.String(length=255), server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='oauth_provider_app_pkey')
|
||||
)
|
||||
with op.batch_alter_table('oauth_provider_apps', schema=None) as batch_op:
|
||||
batch_op.create_index('oauth_provider_app_client_id_idx', ['client_id'], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('oauth_provider_apps', schema=None) as batch_op:
|
||||
batch_op.drop_index('oauth_provider_app_client_id_idx')
|
||||
|
||||
op.drop_table('oauth_provider_apps')
|
||||
# ### end Alembic commands ###
|
@@ -580,6 +580,32 @@ class InstalledApp(Base):
|
||||
return tenant
|
||||
|
||||
|
||||
class OAuthProviderApp(Base):
|
||||
"""
|
||||
Globally shared OAuth provider app information.
|
||||
Only for Dify Cloud.
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_provider_apps"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="oauth_provider_app_pkey"),
|
||||
sa.Index("oauth_provider_app_client_id_idx", "client_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_icon = mapped_column(String(255), nullable=False)
|
||||
app_label = mapped_column(sa.JSON, nullable=False, server_default="{}")
|
||||
client_id = mapped_column(String(255), nullable=False)
|
||||
client_secret = mapped_column(String(255), nullable=False)
|
||||
redirect_uris = mapped_column(sa.JSON, nullable=False, server_default="[]")
|
||||
scope = mapped_column(
|
||||
String(255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"),
|
||||
)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
|
||||
|
||||
|
||||
class Conversation(Base):
|
||||
__tablename__ = "conversations"
|
||||
__table_args__ = (
|
||||
|
94
api/services/oauth_server.py
Normal file
94
api/services/oauth_server.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import enum
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
class OAuthGrantType(enum.StrEnum):
|
||||
AUTHORIZATION_CODE = "authorization_code"
|
||||
REFRESH_TOKEN = "refresh_token"
|
||||
|
||||
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY = "oauth_provider:{client_id}:authorization_code:{code}"
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:access_token:{token}"
|
||||
OAUTH_ACCESS_TOKEN_EXPIRES_IN = 60 * 60 * 12 # 12 hours
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:refresh_token:{token}"
|
||||
OAUTH_REFRESH_TOKEN_EXPIRES_IN = 60 * 60 * 24 * 30 # 30 days
|
||||
|
||||
|
||||
class OAuthServerService:
|
||||
@staticmethod
|
||||
def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
|
||||
query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
return session.execute(query).scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
def sign_oauth_authorization_code(client_id: str, user_account_id: str) -> str:
|
||||
code = str(uuid.uuid4())
|
||||
redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
|
||||
redis_client.set(redis_key, user_account_id, ex=60 * 10) # 10 minutes
|
||||
return code
|
||||
|
||||
@staticmethod
|
||||
def sign_oauth_access_token(
|
||||
grant_type: OAuthGrantType,
|
||||
code: str = "",
|
||||
client_id: str = "",
|
||||
refresh_token: str = "",
|
||||
) -> tuple[str, str]:
|
||||
match grant_type:
|
||||
case OAuthGrantType.AUTHORIZATION_CODE:
|
||||
redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
|
||||
user_account_id = redis_client.get(redis_key)
|
||||
if not user_account_id:
|
||||
raise BadRequest("invalid code")
|
||||
|
||||
# delete code
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
|
||||
refresh_token = OAuthServerService._sign_oauth_refresh_token(client_id, user_account_id)
|
||||
return access_token, refresh_token
|
||||
case OAuthGrantType.REFRESH_TOKEN:
|
||||
redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=refresh_token)
|
||||
user_account_id = redis_client.get(redis_key)
|
||||
if not user_account_id:
|
||||
raise BadRequest("invalid refresh token")
|
||||
|
||||
access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
|
||||
return access_token, refresh_token
|
||||
|
||||
@staticmethod
|
||||
def _sign_oauth_access_token(client_id: str, user_account_id: str) -> str:
|
||||
token = str(uuid.uuid4())
|
||||
redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
|
||||
redis_client.set(redis_key, user_account_id, ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN)
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def _sign_oauth_refresh_token(client_id: str, user_account_id: str) -> str:
|
||||
token = str(uuid.uuid4())
|
||||
redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
|
||||
redis_client.set(redis_key, user_account_id, ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN)
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def validate_oauth_access_token(client_id: str, token: str) -> Account | None:
|
||||
redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
|
||||
user_account_id = redis_client.get(redis_key)
|
||||
if not user_account_id:
|
||||
return None
|
||||
|
||||
user_id_str = user_account_id.decode("utf-8")
|
||||
|
||||
return AccountService.load_user(user_id_str)
|
@@ -2,11 +2,11 @@
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { RiArrowRightUpLine, RiRobot2Line } from '@remixicon/react'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import Button from '../components/base/button'
|
||||
import Avatar from './avatar'
|
||||
import Button from '@/app/components/base/button'
|
||||
import DifyLogo from '@/app/components/base/logo/dify-logo'
|
||||
import { useCallback } from 'react'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import Avatar from './avatar'
|
||||
|
||||
const Header = () => {
|
||||
const { t } = useTranslation()
|
37
web/app/account/oauth/authorize/layout.tsx
Normal file
37
web/app/account/oauth/authorize/layout.tsx
Normal file
@@ -0,0 +1,37 @@
|
||||
'use client'
|
||||
import Header from '@/app/signin/_header'
|
||||
|
||||
import cn from '@/utils/classnames'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import useDocumentTitle from '@/hooks/use-document-title'
|
||||
import { AppContextProvider } from '@/context/app-context'
|
||||
import { useMemo } from 'react'
|
||||
|
||||
export default function SignInLayout({ children }: any) {
|
||||
const { systemFeatures } = useGlobalPublicStore()
|
||||
useDocumentTitle('')
|
||||
const isLoggedIn = useMemo(() => {
|
||||
try {
|
||||
return Boolean(localStorage.getItem('console_token') && localStorage.getItem('refresh_token'))
|
||||
}
|
||||
catch { return false }
|
||||
}, [])
|
||||
return <>
|
||||
<div className={cn('flex min-h-screen w-full justify-center bg-background-default-burn p-6')}>
|
||||
<div className={cn('flex w-full shrink-0 flex-col items-center rounded-2xl border border-effects-highlight bg-background-default-subtle')}>
|
||||
<Header />
|
||||
<div className={cn('flex w-full grow flex-col items-center justify-center px-6 md:px-[108px]')}>
|
||||
<div className='flex flex-col md:w-[400px]'>
|
||||
{isLoggedIn ? <AppContextProvider>
|
||||
{children}
|
||||
</AppContextProvider>
|
||||
: children}
|
||||
</div>
|
||||
</div>
|
||||
{systemFeatures.branding.enabled === false && <div className='system-xs-regular px-8 py-6 text-text-tertiary'>
|
||||
© {new Date().getFullYear()} LangGenius, Inc. All rights reserved.
|
||||
</div>}
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
}
|
205
web/app/account/oauth/authorize/page.tsx
Normal file
205
web/app/account/oauth/authorize/page.tsx
Normal file
@@ -0,0 +1,205 @@
|
||||
'use client'
|
||||
|
||||
import React, { useEffect, useMemo, useRef } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useAuthorizeOAuthApp, useOAuthAppInfo } from '@/service/use-oauth'
|
||||
import {
|
||||
RiAccountCircleLine,
|
||||
RiGlobalLine,
|
||||
RiInfoCardLine,
|
||||
RiMailLine,
|
||||
RiTranslate2,
|
||||
} from '@remixicon/react'
|
||||
import dayjs from 'dayjs'
|
||||
|
||||
export const OAUTH_AUTHORIZE_PENDING_KEY = 'oauth_authorize_pending'
|
||||
export const REDIRECT_URL_KEY = 'oauth_redirect_url'
|
||||
|
||||
const OAUTH_AUTHORIZE_PENDING_TTL = 60 * 3
|
||||
|
||||
function setItemWithExpiry(key: string, value: string, ttl: number) {
|
||||
const item = {
|
||||
value,
|
||||
expiry: dayjs().add(ttl, 'seconds').unix(),
|
||||
}
|
||||
localStorage.setItem(key, JSON.stringify(item))
|
||||
}
|
||||
|
||||
function buildReturnUrl(pathname: string, search: string) {
|
||||
try {
|
||||
const base = `${globalThis.location.origin}${pathname}${search}`
|
||||
return base
|
||||
}
|
||||
catch {
|
||||
return pathname + search
|
||||
}
|
||||
}
|
||||
|
||||
export default function OAuthAuthorize() {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const SCOPE_INFO_MAP: Record<string, { icon: React.ComponentType<{ className?: string }>, label: string }> = {
|
||||
'read:name': {
|
||||
icon: RiInfoCardLine,
|
||||
label: t('oauth.scopes.name'),
|
||||
},
|
||||
'read:email': {
|
||||
icon: RiMailLine,
|
||||
label: t('oauth.scopes.email'),
|
||||
},
|
||||
'read:avatar': {
|
||||
icon: RiAccountCircleLine,
|
||||
label: t('oauth.scopes.avatar'),
|
||||
},
|
||||
'read:interface_language': {
|
||||
icon: RiTranslate2,
|
||||
label: t('oauth.scopes.languagePreference'),
|
||||
},
|
||||
'read:timezone': {
|
||||
icon: RiGlobalLine,
|
||||
label: t('oauth.scopes.timezone'),
|
||||
},
|
||||
}
|
||||
|
||||
const router = useRouter()
|
||||
const language = useLanguage()
|
||||
const searchParams = useSearchParams()
|
||||
const client_id = decodeURIComponent(searchParams.get('client_id') || '')
|
||||
const redirect_uri = decodeURIComponent(searchParams.get('redirect_uri') || '')
|
||||
const { userProfile } = useAppContext()
|
||||
const { data: authAppInfo, isLoading, isError } = useOAuthAppInfo(client_id, redirect_uri)
|
||||
const { mutateAsync: authorize, isPending: authorizing } = useAuthorizeOAuthApp()
|
||||
const hasNotifiedRef = useRef(false)
|
||||
|
||||
const isLoggedIn = useMemo(() => {
|
||||
try {
|
||||
return Boolean(localStorage.getItem('console_token') && localStorage.getItem('refresh_token'))
|
||||
}
|
||||
catch { return false }
|
||||
}, [])
|
||||
|
||||
const onLoginSwitchClick = () => {
|
||||
try {
|
||||
const returnUrl = buildReturnUrl('/account/oauth/authorize', `?client_id=${encodeURIComponent(client_id)}&redirect_uri=${encodeURIComponent(redirect_uri)}`)
|
||||
setItemWithExpiry(OAUTH_AUTHORIZE_PENDING_KEY, returnUrl, OAUTH_AUTHORIZE_PENDING_TTL)
|
||||
router.push(`/signin?${REDIRECT_URL_KEY}=${encodeURIComponent(returnUrl)}`)
|
||||
}
|
||||
catch {
|
||||
router.push('/signin')
|
||||
}
|
||||
}
|
||||
|
||||
const onAuthorize = async () => {
|
||||
if (!client_id || !redirect_uri)
|
||||
return
|
||||
try {
|
||||
const { code } = await authorize({ client_id })
|
||||
const url = new URL(redirect_uri)
|
||||
url.searchParams.set('code', code)
|
||||
globalThis.location.href = url.toString()
|
||||
}
|
||||
catch (err: any) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: `${t('oauth.error.authorizeFailed')}: ${err.message}`,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
const invalidParams = !client_id || !redirect_uri
|
||||
if ((invalidParams || isError) && !hasNotifiedRef.current) {
|
||||
hasNotifiedRef.current = true
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: invalidParams ? t('oauth.error.invalidParams') : t('oauth.error.authAppInfoFetchFailed'),
|
||||
duration: 0,
|
||||
})
|
||||
}
|
||||
}, [client_id, redirect_uri, isError])
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className='bg-background-default-subtle'>
|
||||
<Loading type='app' />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className='bg-background-default-subtle'>
|
||||
{authAppInfo?.app_icon && (
|
||||
<div className='w-max rounded-2xl border-[0.5px] border-components-panel-border bg-text-primary-on-surface p-3 shadow-lg'>
|
||||
<img src={authAppInfo.app_icon} alt='app icon' className='h-10 w-10 rounded' />
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className={`mb-4 mt-5 flex flex-col gap-2 ${isLoggedIn ? 'pb-2' : ''}`}>
|
||||
<div className='title-4xl-semi-bold'>
|
||||
{isLoggedIn && <div className='text-text-primary'>{t('oauth.connect')}</div>}
|
||||
<div className='text-[var(--color-saas-dify-blue-inverted)]'>{authAppInfo?.app_label[language] || authAppInfo?.app_label?.en_US || t('oauth.unknownApp')}</div>
|
||||
{!isLoggedIn && <div className='text-text-primary'>{t('oauth.tips.notLoggedIn')}</div>}
|
||||
</div>
|
||||
<div className='body-md-regular text-text-secondary'>{isLoggedIn ? `${authAppInfo?.app_label[language] || authAppInfo?.app_label?.en_US || t('oauth.unknownApp')} ${t('oauth.tips.loggedIn')}` : t('oauth.tips.needLogin')}</div>
|
||||
</div>
|
||||
|
||||
{isLoggedIn && userProfile && (
|
||||
<div className='flex items-center justify-between rounded-xl bg-background-section-burn-inverted p-3'>
|
||||
<div className='flex items-center gap-2.5'>
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size={36} />
|
||||
<div>
|
||||
<div className='system-md-semi-bold text-text-secondary'>{userProfile.name}</div>
|
||||
<div className='system-xs-regular text-text-tertiary'>{userProfile.email}</div>
|
||||
</div>
|
||||
</div>
|
||||
<Button variant='tertiary' size='small' onClick={onLoginSwitchClick}>{t('oauth.switchAccount')}</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isLoggedIn && Boolean(authAppInfo?.scope) && (
|
||||
<div className='mt-2 flex flex-col gap-2.5 rounded-xl bg-background-section-burn-inverted px-[22px] py-5 text-text-secondary'>
|
||||
{authAppInfo!.scope.split(/\s+/).filter(Boolean).map((scope: string) => {
|
||||
const Icon = SCOPE_INFO_MAP[scope]
|
||||
return (
|
||||
<div key={scope} className='body-sm-medium flex items-center gap-2 text-text-secondary'>
|
||||
{Icon ? <Icon.icon className='h-4 w-4' /> : <RiAccountCircleLine className='h-4 w-4' />}
|
||||
{Icon.label}
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className='flex flex-col items-center gap-2 pt-4'>
|
||||
{!isLoggedIn ? (
|
||||
<Button variant='primary' size='large' className='w-full' onClick={onLoginSwitchClick}>{t('oauth.login')}</Button>
|
||||
) : (
|
||||
<>
|
||||
<Button variant='primary' size='large' className='w-full' onClick={onAuthorize} disabled={!client_id || !redirect_uri || isError || authorizing} loading={authorizing}>{t('oauth.continue')}</Button>
|
||||
<Button size='large' className='w-full' onClick={() => router.push('/apps')}>{t('common.operation.cancel')}</Button>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
<div className='mt-4 py-2'>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="400" height="1" viewBox="0 0 400 1" fill="none">
|
||||
<path d="M0 0.5H400" stroke="url(#paint0_linear_2_5904)" />
|
||||
<defs>
|
||||
<linearGradient id="paint0_linear_2_5904" x1="400" y1="9.49584" x2="0.000228929" y2="9.17666" gradientUnits="userSpaceOnUse">
|
||||
<stop stop-color="white" stop-opacity="0.01" />
|
||||
<stop offset="0.505" stop-color="#101828" stop-opacity="0.08" />
|
||||
<stop offset="1" stop-color="white" stop-opacity="0.01" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
</div>
|
||||
<div className='system-xs-regular mt-3 text-text-tertiary'>{t('oauth.tips.common')}</div>
|
||||
</div>
|
||||
)
|
||||
}
|
@@ -56,8 +56,7 @@ const Toast = ({
|
||||
'top-0',
|
||||
'right-0',
|
||||
)}>
|
||||
<div className={`absolute inset-0 -z-10 opacity-40 ${
|
||||
(type === 'success' && 'bg-toast-success-bg')
|
||||
<div className={`absolute inset-0 -z-10 opacity-40 ${(type === 'success' && 'bg-toast-success-bg')
|
||||
|| (type === 'warning' && 'bg-toast-warning-bg')
|
||||
|| (type === 'error' && 'bg-toast-error-bg')
|
||||
|| (type === 'info' && 'bg-toast-info-bg')
|
||||
@@ -162,7 +161,9 @@ Toast.notify = ({
|
||||
</ToastContext.Provider>,
|
||||
)
|
||||
document.body.appendChild(holder)
|
||||
setTimeout(toastHandler.clear, duration || defaultDuring)
|
||||
const d = duration ?? defaultDuring
|
||||
if (d > 0)
|
||||
setTimeout(toastHandler.clear, d)
|
||||
}
|
||||
|
||||
return toastHandler
|
||||
|
@@ -9,6 +9,7 @@ import {
|
||||
EDUCATION_VERIFYING_LOCALSTORAGE_ITEM,
|
||||
EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,
|
||||
} from '@/app/education-apply/constants'
|
||||
import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect'
|
||||
|
||||
type SwrInitializerProps = {
|
||||
children: ReactNode
|
||||
@@ -63,6 +64,10 @@ const SwrInitializer = ({
|
||||
if (searchParams.has('access_token') || searchParams.has('refresh_token')) {
|
||||
consoleToken && localStorage.setItem('console_token', consoleToken)
|
||||
refreshToken && localStorage.setItem('refresh_token', refreshToken)
|
||||
const redirectUrl = resolvePostLoginRedirect(searchParams)
|
||||
if (redirectUrl)
|
||||
location.replace(redirectUrl)
|
||||
else
|
||||
router.replace(pathname)
|
||||
}
|
||||
|
||||
|
@@ -10,6 +10,7 @@ import Input from '@/app/components/base/input'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { emailLoginWithCode, sendEMailLoginCode } from '@/service/common'
|
||||
import I18NContext from '@/context/i18n'
|
||||
import { resolvePostLoginRedirect } from '../utils/post-login-redirect'
|
||||
|
||||
export default function CheckCode() {
|
||||
const { t } = useTranslation()
|
||||
@@ -43,7 +44,13 @@ export default function CheckCode() {
|
||||
if (ret.result === 'success') {
|
||||
localStorage.setItem('console_token', ret.data.access_token)
|
||||
localStorage.setItem('refresh_token', ret.data.refresh_token)
|
||||
router.replace(invite_token ? `/signin/invite-settings?${searchParams.toString()}` : '/apps')
|
||||
if (invite_token) {
|
||||
router.replace(`/signin/invite-settings?${searchParams.toString()}`)
|
||||
}
|
||||
else {
|
||||
const redirectUrl = resolvePostLoginRedirect(searchParams)
|
||||
router.replace(redirectUrl || '/apps')
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (error) { console.error(error) }
|
||||
|
@@ -10,6 +10,7 @@ import { login } from '@/service/common'
|
||||
import Input from '@/app/components/base/input'
|
||||
import I18NContext from '@/context/i18n'
|
||||
import { noop } from 'lodash-es'
|
||||
import { resolvePostLoginRedirect } from '../utils/post-login-redirect'
|
||||
|
||||
type MailAndPasswordAuthProps = {
|
||||
isInvite: boolean
|
||||
@@ -74,7 +75,8 @@ export default function MailAndPasswordAuth({ isInvite, isEmailSetup, allowRegis
|
||||
else {
|
||||
localStorage.setItem('console_token', res.data.access_token)
|
||||
localStorage.setItem('refresh_token', res.data.refresh_token)
|
||||
router.replace('/apps')
|
||||
const redirectUrl = resolvePostLoginRedirect(searchParams)
|
||||
router.replace(redirectUrl || '/apps')
|
||||
}
|
||||
}
|
||||
else if (res.code === 'account_not_found') {
|
||||
|
@@ -18,6 +18,7 @@ import Loading from '@/app/components/base/loading'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { noop } from 'lodash-es'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { resolvePostLoginRedirect } from '../utils/post-login-redirect'
|
||||
|
||||
export default function InviteSettingsPage() {
|
||||
const { t } = useTranslation()
|
||||
@@ -60,7 +61,8 @@ export default function InviteSettingsPage() {
|
||||
localStorage.setItem('console_token', res.data.access_token)
|
||||
localStorage.setItem('refresh_token', res.data.refresh_token)
|
||||
await setLocaleOnClient(language, false)
|
||||
router.replace('/apps')
|
||||
const redirectUrl = resolvePostLoginRedirect(searchParams)
|
||||
router.replace(redirectUrl || '/apps')
|
||||
}
|
||||
}
|
||||
catch {
|
||||
|
@@ -10,7 +10,7 @@ export default function SignInLayout({ children }: any) {
|
||||
useDocumentTitle('')
|
||||
return <>
|
||||
<div className={cn('flex min-h-screen w-full justify-center bg-background-default-burn p-6')}>
|
||||
<div className={cn('flex w-full shrink-0 flex-col rounded-2xl border border-effects-highlight bg-background-default-subtle')}>
|
||||
<div className={cn('flex w-full shrink-0 flex-col items-center rounded-2xl border border-effects-highlight bg-background-default-subtle')}>
|
||||
<Header />
|
||||
<div className={cn('flex w-full grow flex-col items-center justify-center px-6 md:px-[108px]')}>
|
||||
<div className='flex flex-col md:w-[400px]'>
|
||||
|
@@ -14,6 +14,7 @@ import { LicenseStatus } from '@/types/feature'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { resolvePostLoginRedirect } from './utils/post-login-redirect'
|
||||
|
||||
const NormalForm = () => {
|
||||
const { t } = useTranslation()
|
||||
@@ -37,7 +38,8 @@ const NormalForm = () => {
|
||||
if (consoleToken && refreshToken) {
|
||||
localStorage.setItem('console_token', consoleToken)
|
||||
localStorage.setItem('refresh_token', refreshToken)
|
||||
router.replace('/apps')
|
||||
const redirectUrl = resolvePostLoginRedirect(searchParams)
|
||||
router.replace(redirectUrl || '/apps')
|
||||
return
|
||||
}
|
||||
|
||||
|
36
web/app/signin/utils/post-login-redirect.ts
Normal file
36
web/app/signin/utils/post-login-redirect.ts
Normal file
@@ -0,0 +1,36 @@
|
||||
import { OAUTH_AUTHORIZE_PENDING_KEY, REDIRECT_URL_KEY } from '@/app/account/oauth/authorize/page'
|
||||
import dayjs from 'dayjs'
|
||||
import type { ReadonlyURLSearchParams } from 'next/navigation'
|
||||
|
||||
function getItemWithExpiry(key: string): string | null {
|
||||
const itemStr = localStorage.getItem(key)
|
||||
if (!itemStr)
|
||||
return null
|
||||
|
||||
try {
|
||||
const item = JSON.parse(itemStr)
|
||||
localStorage.removeItem(key)
|
||||
if (!item?.value) return null
|
||||
|
||||
return dayjs().unix() > item.expiry ? null : item.value
|
||||
}
|
||||
catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
export const resolvePostLoginRedirect = (searchParams: ReadonlyURLSearchParams) => {
|
||||
const redirectUrl = searchParams.get(REDIRECT_URL_KEY)
|
||||
if (redirectUrl) {
|
||||
try {
|
||||
localStorage.removeItem(OAUTH_AUTHORIZE_PENDING_KEY)
|
||||
return decodeURIComponent(redirectUrl)
|
||||
}
|
||||
catch (e) {
|
||||
console.error('Failed to decode redirect URL:', e)
|
||||
return redirectUrl
|
||||
}
|
||||
}
|
||||
|
||||
return getItemWithExpiry(OAUTH_AUTHORIZE_PENDING_KEY)
|
||||
}
|
@@ -34,6 +34,7 @@ const NAMESPACES = [
|
||||
'explore',
|
||||
'layout',
|
||||
'login',
|
||||
'oauth',
|
||||
'plugin-tags',
|
||||
'plugin',
|
||||
'register',
|
||||
|
27
web/i18n/en-US/oauth.ts
Normal file
27
web/i18n/en-US/oauth.ts
Normal file
@@ -0,0 +1,27 @@
|
||||
const translation = {
|
||||
tips: {
|
||||
loggedIn: 'wants to access the following information from your Dify Cloud account.',
|
||||
notLoggedIn: 'wants to access your Dify Cloud account',
|
||||
needLogin: 'Please log in to authorize',
|
||||
common: 'We respect your privacy and will only use this information to enhance your experience with our developer tools.',
|
||||
},
|
||||
connect: 'Connect to',
|
||||
continue: 'Continue',
|
||||
switchAccount: 'Switch Account',
|
||||
login: 'Login',
|
||||
scopes: {
|
||||
name: 'Name',
|
||||
email: 'Email',
|
||||
avatar: 'Avatar',
|
||||
languagePreference: 'Language Preference',
|
||||
timezone: 'Timezone',
|
||||
},
|
||||
error: {
|
||||
invalidParams: 'Invalid parameters',
|
||||
authorizeFailed: 'Authorize failed',
|
||||
authAppInfoFetchFailed: 'Failed to fetch app info for authorization',
|
||||
},
|
||||
unknownApp: 'Unknown App',
|
||||
}
|
||||
|
||||
export default translation
|
27
web/i18n/zh-Hans/oauth.ts
Normal file
27
web/i18n/zh-Hans/oauth.ts
Normal file
@@ -0,0 +1,27 @@
|
||||
const translation = {
|
||||
tips: {
|
||||
loggedIn: '想要访问您的 Dify Cloud 账号中的以下信息。',
|
||||
notLoggedIn: '想要访问您的 Dify Cloud 账号',
|
||||
needLogin: '请先登录以授权',
|
||||
common: '我们尊重您的隐私,并仅使用此信息来增强您对我们开发工具的使用体验。',
|
||||
},
|
||||
connect: '连接到',
|
||||
continue: '继续',
|
||||
switchAccount: '切换账号',
|
||||
login: '登录',
|
||||
scopes: {
|
||||
name: '名称',
|
||||
email: '邮箱',
|
||||
avatar: '头像',
|
||||
languagePreference: '语言偏好',
|
||||
timezone: '时区',
|
||||
},
|
||||
error: {
|
||||
invalidParams: '无效的参数',
|
||||
authorizeFailed: '授权失败',
|
||||
authAppInfoFetchFailed: '获取待授权应用的信息失败',
|
||||
},
|
||||
unknownApp: '未知应用',
|
||||
}
|
||||
|
||||
export default translation
|
29
web/service/use-oauth.ts
Normal file
29
web/service/use-oauth.ts
Normal file
@@ -0,0 +1,29 @@
|
||||
import { post } from './base'
|
||||
import { useMutation, useQuery } from '@tanstack/react-query'
|
||||
|
||||
const NAME_SPACE = 'oauth-provider'
|
||||
|
||||
export type OAuthAppInfo = {
|
||||
app_icon: string
|
||||
app_label: Record<string, string>
|
||||
scope: string
|
||||
}
|
||||
|
||||
export type OAuthAuthorizeResponse = {
|
||||
code: string
|
||||
}
|
||||
|
||||
export const useOAuthAppInfo = (client_id: string, redirect_uri: string) => {
|
||||
return useQuery<OAuthAppInfo>({
|
||||
queryKey: [NAME_SPACE, 'authAppInfo', client_id, redirect_uri],
|
||||
queryFn: () => post<OAuthAppInfo>('/oauth/provider', { body: { client_id, redirect_uri } }, { silent: true }),
|
||||
enabled: Boolean(client_id && redirect_uri),
|
||||
})
|
||||
}
|
||||
|
||||
export const useAuthorizeOAuthApp = () => {
|
||||
return useMutation({
|
||||
mutationKey: [NAME_SPACE, 'authorize'],
|
||||
mutationFn: (payload: { client_id: string }) => post<OAuthAuthorizeResponse>('/oauth/provider/authorize', { body: payload }),
|
||||
})
|
||||
}
|
Reference in New Issue
Block a user