feat: implement forgot password feature (#5534)
This commit is contained in:
@@ -17,6 +17,10 @@ class SecurityConfig(BaseModel):
|
||||
default=None,
|
||||
)
|
||||
|
||||
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
|
||||
description='Expiry time in hours for reset token',
|
||||
default=24,
|
||||
)
|
||||
|
||||
class AppExecutionConfig(BaseModel):
|
||||
"""
|
||||
|
@@ -30,7 +30,7 @@ from .app import (
|
||||
)
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, login, oauth
|
||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth
|
||||
|
||||
# Import billing controllers
|
||||
from .billing import billing
|
||||
|
@@ -5,3 +5,28 @@ class ApiKeyAuthFailedError(BaseHTTPException):
|
||||
error_code = 'auth_failed'
|
||||
description = "{message}"
|
||||
code = 500
|
||||
|
||||
|
||||
class InvalidEmailError(BaseHTTPException):
|
||||
error_code = 'invalid_email'
|
||||
description = "The email address is not valid."
|
||||
code = 400
|
||||
|
||||
|
||||
class PasswordMismatchError(BaseHTTPException):
|
||||
error_code = 'password_mismatch'
|
||||
description = "The passwords do not match."
|
||||
code = 400
|
||||
|
||||
|
||||
class InvalidTokenError(BaseHTTPException):
|
||||
error_code = 'invalid_or_expired_token'
|
||||
description = "The token is invalid or has expired."
|
||||
code = 400
|
||||
|
||||
|
||||
class PasswordResetRateLimitExceededError(BaseHTTPException):
|
||||
error_code = 'password_reset_rate_limit_exceeded'
|
||||
description = "Password reset rate limit exceeded. Try again later."
|
||||
code = 429
|
||||
|
||||
|
107
api/controllers/console/auth/forgot_password.py
Normal file
107
api/controllers/console/auth/forgot_password.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import base64
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.auth.error import (
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
PasswordResetRateLimitExceededError,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email as email_validate
|
||||
from libs.password import hash_password, valid_password
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService
|
||||
from services.errors.account import RateLimitExceededError
|
||||
|
||||
|
||||
class ForgotPasswordSendEmailApi(Resource):
|
||||
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('email', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
email = args['email']
|
||||
|
||||
if not email_validate(email):
|
||||
raise InvalidEmailError()
|
||||
|
||||
account = Account.query.filter_by(email=email).first()
|
||||
|
||||
if account:
|
||||
try:
|
||||
AccountService.send_reset_password_email(account=account)
|
||||
except RateLimitExceededError:
|
||||
logging.warning(f"Rate limit exceeded for email: {account.email}")
|
||||
raise PasswordResetRateLimitExceededError()
|
||||
else:
|
||||
# Return success to avoid revealing email registration status
|
||||
logging.warning(f"Attempt to reset password for unregistered email: {email}")
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class ForgotPasswordCheckApi(Resource):
|
||||
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
token = args['token']
|
||||
|
||||
reset_data = AccountService.get_reset_password_data(token)
|
||||
|
||||
if reset_data is None:
|
||||
return {'is_valid': False, 'email': None}
|
||||
return {'is_valid': True, 'email': reset_data.get('email')}
|
||||
|
||||
|
||||
class ForgotPasswordResetApi(Resource):
|
||||
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json')
|
||||
parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
new_password = args['new_password']
|
||||
password_confirm = args['password_confirm']
|
||||
|
||||
if str(new_password).strip() != str(password_confirm).strip():
|
||||
raise PasswordMismatchError()
|
||||
|
||||
token = args['token']
|
||||
reset_data = AccountService.get_reset_password_data(token)
|
||||
|
||||
if reset_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
AccountService.revoke_reset_password_token(token)
|
||||
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
|
||||
account = Account.query.filter_by(email=reset_data.get('email')).first()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password')
|
||||
api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity')
|
||||
api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets')
|
@@ -245,6 +245,8 @@ class AccountIntegrateApi(Resource):
|
||||
return {'data': integrate_data}
|
||||
|
||||
|
||||
|
||||
|
||||
# Register API resources
|
||||
api.add_resource(AccountInitApi, '/account/init')
|
||||
api.add_resource(AccountProfileApi, '/account/profile')
|
||||
|
@@ -1,18 +1,23 @@
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from typing import Union
|
||||
from typing import Any, Optional, Union
|
||||
from zoneinfo import available_timezones
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask import Response, current_app, stream_with_context
|
||||
from flask_restful import fields
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account
|
||||
|
||||
|
||||
def run(script):
|
||||
return subprocess.getstatusoutput('source /root/.bashrc && ' + script)
|
||||
@@ -46,12 +51,12 @@ def uuid_value(value):
|
||||
error = ('{value} is not a valid uuid.'
|
||||
.format(value=value))
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
def alphanumeric(value: str):
|
||||
# check if the value is alphanumeric and underlined
|
||||
if re.match(r'^[a-zA-Z0-9_]+$', value):
|
||||
return value
|
||||
|
||||
|
||||
raise ValueError(f'{value} is not a valid alphanumeric value')
|
||||
|
||||
def timestamp_value(timestamp):
|
||||
@@ -163,3 +168,97 @@ def compact_generate_response(response: Union[dict, Generator]) -> Response:
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
|
||||
class TokenManager:
|
||||
|
||||
@classmethod
|
||||
def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
|
||||
old_token = cls._get_current_token_for_account(account.id, token_type)
|
||||
if old_token:
|
||||
if isinstance(old_token, bytes):
|
||||
old_token = old_token.decode('utf-8')
|
||||
cls.revoke_token(old_token, token_type)
|
||||
|
||||
token = str(uuid.uuid4())
|
||||
token_data = {
|
||||
'account_id': account.id,
|
||||
'email': account.email,
|
||||
'token_type': token_type
|
||||
}
|
||||
if additional_data:
|
||||
token_data.update(additional_data)
|
||||
|
||||
expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS']
|
||||
token_key = cls._get_token_key(token, token_type)
|
||||
redis_client.setex(
|
||||
token_key,
|
||||
expiry_hours * 60 * 60,
|
||||
json.dumps(token_data)
|
||||
)
|
||||
|
||||
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def _get_token_key(cls, token: str, token_type: str) -> str:
|
||||
return f'{token_type}:token:{token}'
|
||||
|
||||
@classmethod
|
||||
def revoke_token(cls, token: str, token_type: str):
|
||||
token_key = cls._get_token_key(token, token_type)
|
||||
redis_client.delete(token_key)
|
||||
|
||||
@classmethod
|
||||
def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]:
|
||||
key = cls._get_token_key(token, token_type)
|
||||
token_data_json = redis_client.get(key)
|
||||
if token_data_json is None:
|
||||
logging.warning(f"{token_type} token {token} not found with key {key}")
|
||||
return None
|
||||
token_data = json.loads(token_data_json)
|
||||
return token_data
|
||||
|
||||
@classmethod
|
||||
def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]:
|
||||
key = cls._get_account_token_key(account_id, token_type)
|
||||
current_token = redis_client.get(key)
|
||||
return current_token
|
||||
|
||||
@classmethod
|
||||
def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_hours: int):
|
||||
key = cls._get_account_token_key(account_id, token_type)
|
||||
redis_client.setex(key, expiry_hours * 60 * 60, token)
|
||||
|
||||
@classmethod
|
||||
def _get_account_token_key(cls, account_id: str, token_type: str) -> str:
|
||||
return f'{token_type}:account:{account_id}'
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
def __init__(self, prefix: str, max_attempts: int, time_window: int):
|
||||
self.prefix = prefix
|
||||
self.max_attempts = max_attempts
|
||||
self.time_window = time_window
|
||||
|
||||
def _get_key(self, email: str) -> str:
|
||||
return f"{self.prefix}:{email}"
|
||||
|
||||
def is_rate_limited(self, email: str) -> bool:
|
||||
key = self._get_key(email)
|
||||
current_time = int(time.time())
|
||||
window_start_time = current_time - self.time_window
|
||||
|
||||
redis_client.zremrangebyscore(key, '-inf', window_start_time)
|
||||
attempts = redis_client.zcard(key)
|
||||
|
||||
if attempts and int(attempts) >= self.max_attempts:
|
||||
return True
|
||||
return False
|
||||
|
||||
def increment_rate_limit(self, email: str):
|
||||
key = self._get_key(email)
|
||||
current_time = int(time.time())
|
||||
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
redis_client.expire(key, self.time_window * 2)
|
||||
|
@@ -13,6 +13,7 @@ from werkzeug.exceptions import Unauthorized
|
||||
from constants.languages import language_timezone_mapping, languages
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.helper import RateLimiter, TokenManager
|
||||
from libs.passport import PassportService
|
||||
from libs.password import compare_password, hash_password, valid_password
|
||||
from libs.rsa import generate_key_pair
|
||||
@@ -29,14 +30,22 @@ from services.errors.account import (
|
||||
LinkAccountIntegrateError,
|
||||
MemberNotInTenantError,
|
||||
NoPermissionError,
|
||||
RateLimitExceededError,
|
||||
RoleAlreadyAssignedError,
|
||||
TenantNotFound,
|
||||
)
|
||||
from tasks.mail_invite_member_task import send_invite_member_mail_task
|
||||
from tasks.mail_reset_password_task import send_reset_password_mail_task
|
||||
|
||||
|
||||
class AccountService:
|
||||
|
||||
reset_password_rate_limiter = RateLimiter(
|
||||
prefix="reset_password_rate_limit",
|
||||
max_attempts=5,
|
||||
time_window=60 * 60
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_user(user_id: str) -> Account:
|
||||
account = Account.query.filter_by(id=user_id).first()
|
||||
@@ -222,9 +231,33 @@ class AccountService:
|
||||
return None
|
||||
return AccountService.load_user(account_id)
|
||||
|
||||
@classmethod
|
||||
def send_reset_password_email(cls, account):
|
||||
if cls.reset_password_rate_limiter.is_rate_limited(account.email):
|
||||
raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. Please try again later.")
|
||||
|
||||
token = TokenManager.generate_token(account, 'reset_password')
|
||||
send_reset_password_mail_task.delay(
|
||||
language=account.interface_language,
|
||||
to=account.email,
|
||||
token=token
|
||||
)
|
||||
cls.reset_password_rate_limiter.increment_rate_limit(account.email)
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def revoke_reset_password_token(cls, token: str):
|
||||
TokenManager.revoke_token(token, 'reset_password')
|
||||
|
||||
@classmethod
|
||||
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
|
||||
return TokenManager.get_token_data(token, 'reset_password')
|
||||
|
||||
|
||||
def _get_login_cache_key(*, account_id: str, token: str):
|
||||
return f"account_login:{account_id}:{token}"
|
||||
|
||||
|
||||
class TenantService:
|
||||
|
||||
@staticmethod
|
||||
|
@@ -51,3 +51,8 @@ class MemberNotInTenantError(BaseServiceError):
|
||||
|
||||
class RoleAlreadyAssignedError(BaseServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class RateLimitExceededError(BaseServiceError):
|
||||
pass
|
||||
|
||||
|
@@ -39,16 +39,15 @@ def send_invite_member_mail_task(language: str, to: str, token: str, inviter_nam
|
||||
mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content)
|
||||
else:
|
||||
html_content = render_template('invite_member_mail_template_en-US.html',
|
||||
to=to,
|
||||
inviter_name=inviter_name,
|
||||
workspace_name=workspace_name,
|
||||
url=url)
|
||||
to=to,
|
||||
inviter_name=inviter_name,
|
||||
workspace_name=workspace_name,
|
||||
url=url)
|
||||
mail.send(to=to, subject="Join Dify Workspace Now", html=html_content)
|
||||
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style('Send invite member mail to {} succeeded: latency: {}'.format(to, end_at - start_at),
|
||||
fg='green'))
|
||||
except Exception:
|
||||
logging.exception("Send invite member mail to {} failed".format(to))
|
||||
logging.exception("Send invite member mail to {} failed".format(to))
|
44
api/tasks/mail_reset_password_task.py
Normal file
44
api/tasks/mail_reset_password_task.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from flask import current_app, render_template
|
||||
|
||||
from extensions.ext_mail import mail
|
||||
|
||||
|
||||
@shared_task(queue='mail')
|
||||
def send_reset_password_mail_task(language: str, to: str, token: str):
|
||||
"""
|
||||
Async Send reset password mail
|
||||
:param language: Language in which the email should be sent (e.g., 'en', 'zh')
|
||||
:param to: Recipient email address
|
||||
:param token: Reset password token to be included in the email
|
||||
"""
|
||||
if not mail.is_inited():
|
||||
return
|
||||
|
||||
logging.info(click.style('Start password reset mail to {}'.format(to), fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
# send reset password mail using different languages
|
||||
try:
|
||||
url = f'{current_app.config.get("CONSOLE_WEB_URL")}/forgot-password?token={token}'
|
||||
if language == 'zh-Hans':
|
||||
html_content = render_template('reset_password_mail_template_zh-CN.html',
|
||||
to=to,
|
||||
url=url)
|
||||
mail.send(to=to, subject="重置您的 Dify 密码", html=html_content)
|
||||
else:
|
||||
html_content = render_template('reset_password_mail_template_en-US.html',
|
||||
to=to,
|
||||
url=url)
|
||||
mail.send(to=to, subject="Reset Your Dify Password", html=html_content)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style('Send password reset mail to {} succeeded: latency: {}'.format(to, end_at - start_at),
|
||||
fg='green'))
|
||||
except Exception:
|
||||
logging.exception("Send password reset mail to {} failed".format(to))
|
72
api/templates/reset_password_mail_template_en-US.html
Normal file
72
api/templates/reset_password_mail_template_en-US.html
Normal file
@@ -0,0 +1,72 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<style>
|
||||
body {
|
||||
font-family: 'Arial', sans-serif;
|
||||
line-height: 16pt;
|
||||
color: #374151;
|
||||
background-color: #E5E7EB;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
.container {
|
||||
width: 100%;
|
||||
max-width: 560px;
|
||||
margin: 40px auto;
|
||||
padding: 20px;
|
||||
background-color: #F3F4F6;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
.header {
|
||||
text-align: center;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.header img {
|
||||
max-width: 100px;
|
||||
height: auto;
|
||||
}
|
||||
.button {
|
||||
display: inline-block;
|
||||
padding: 12px 24px;
|
||||
background-color: #2970FF;
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
border-radius: 4px;
|
||||
text-align: center;
|
||||
transition: background-color 0.3s ease;
|
||||
}
|
||||
.button:hover {
|
||||
background-color: #265DD4;
|
||||
}
|
||||
.footer {
|
||||
font-size: 0.9em;
|
||||
color: #777777;
|
||||
margin-top: 30px;
|
||||
}
|
||||
.content {
|
||||
margin-top: 20px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<img src="https://cloud.dify.ai/logo/logo-site.png" alt="Dify Logo">
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Dear {{ to }},</p>
|
||||
<p>We have received a request to reset your password. If you initiated this request, please click the button below to reset your password:</p>
|
||||
<p style="text-align: center;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">Reset Password</a></p>
|
||||
<p>If you did not request a password reset, please ignore this email and your account will remain secure.</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>Best regards,</p>
|
||||
<p>Dify Team</p>
|
||||
<p>Please do not reply directly to this email; it is automatically sent by the system.</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
72
api/templates/reset_password_mail_template_zh-CN.html
Normal file
72
api/templates/reset_password_mail_template_zh-CN.html
Normal file
@@ -0,0 +1,72 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<style>
|
||||
body {
|
||||
font-family: 'Arial', sans-serif;
|
||||
line-height: 16pt;
|
||||
color: #374151;
|
||||
background-color: #E5E7EB;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
.container {
|
||||
width: 100%;
|
||||
max-width: 560px;
|
||||
margin: 40px auto;
|
||||
padding: 20px;
|
||||
background-color: #F3F4F6;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
.header {
|
||||
text-align: center;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.header img {
|
||||
max-width: 100px;
|
||||
height: auto;
|
||||
}
|
||||
.button {
|
||||
display: inline-block;
|
||||
padding: 12px 24px;
|
||||
background-color: #2970FF;
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
border-radius: 4px;
|
||||
text-align: center;
|
||||
transition: background-color 0.3s ease;
|
||||
}
|
||||
.button:hover {
|
||||
background-color: #265DD4;
|
||||
}
|
||||
.footer {
|
||||
font-size: 0.9em;
|
||||
color: #777777;
|
||||
margin-top: 30px;
|
||||
}
|
||||
.content {
|
||||
margin-top: 20px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<img src="https://cloud.dify.ai/logo/logo-site.png" alt="Dify Logo">
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>尊敬的 {{ to }},</p>
|
||||
<p>我们收到了您关于重置密码的请求。如果是您本人操作,请点击以下按钮重置您的密码:</p>
|
||||
<p style="text-align: center;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">重置密码</a></p>
|
||||
<p>如果您没有请求重置密码,请忽略此邮件,您的账户信息将保持安全。</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>此致,</p>
|
||||
<p>Dify 团队</p>
|
||||
<p>请不要直接回复此电子邮件;由系统自动发送。</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
Reference in New Issue
Block a user