feat: add redis fallback mechanism #21043 (#21044)

Co-authored-by: tech <cto@sb>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
NeatGuyCoding
2025-07-10 10:19:58 +08:00
committed by GitHub
parent a371390d6c
commit 6f8c7a66c8
3 changed files with 89 additions and 1 deletions

View File

@@ -1,6 +1,10 @@
import functools
import logging
from collections.abc import Callable
from typing import Any, Union from typing import Any, Union
import redis import redis
from redis import RedisError
from redis.cache import CacheConfig from redis.cache import CacheConfig
from redis.cluster import ClusterNode, RedisCluster from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection from redis.connection import Connection, SSLConnection
@@ -9,6 +13,8 @@ from redis.sentinel import Sentinel
from configs import dify_config from configs import dify_config
from dify_app import DifyApp from dify_app import DifyApp
logger = logging.getLogger(__name__)
class RedisClientWrapper: class RedisClientWrapper:
""" """
@@ -115,3 +121,25 @@ def init_app(app: DifyApp):
redis_client.initialize(redis.Redis(connection_pool=pool)) redis_client.initialize(redis.Redis(connection_pool=pool))
app.extensions["redis"] = redis_client app.extensions["redis"] = redis_client
def redis_fallback(default_return: Any = None):
"""
decorator to handle Redis operation exceptions and return a default value when Redis is unavailable.
Args:
default_return: The value to return when a Redis operation fails. Defaults to None.
"""
def decorator(func: Callable):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except RedisError as e:
logger.warning(f"Redis operation failed in {func.__name__}: {str(e)}", exc_info=True)
return default_return
return wrapper
return decorator

View File

@@ -16,7 +16,7 @@ from configs import dify_config
from constants.languages import language_timezone_mapping, languages from constants.languages import language_timezone_mapping, languages
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client, redis_fallback
from libs.helper import RateLimiter, TokenManager from libs.helper import RateLimiter, TokenManager
from libs.passport import PassportService from libs.passport import PassportService
from libs.password import compare_password, hash_password, valid_password from libs.password import compare_password, hash_password, valid_password
@@ -495,6 +495,7 @@ class AccountService:
return account return account
@staticmethod @staticmethod
@redis_fallback(default_return=None)
def add_login_error_rate_limit(email: str) -> None: def add_login_error_rate_limit(email: str) -> None:
key = f"login_error_rate_limit:{email}" key = f"login_error_rate_limit:{email}"
count = redis_client.get(key) count = redis_client.get(key)
@@ -504,6 +505,7 @@ class AccountService:
redis_client.setex(key, dify_config.LOGIN_LOCKOUT_DURATION, count) redis_client.setex(key, dify_config.LOGIN_LOCKOUT_DURATION, count)
@staticmethod @staticmethod
@redis_fallback(default_return=False)
def is_login_error_rate_limit(email: str) -> bool: def is_login_error_rate_limit(email: str) -> bool:
key = f"login_error_rate_limit:{email}" key = f"login_error_rate_limit:{email}"
count = redis_client.get(key) count = redis_client.get(key)
@@ -516,11 +518,13 @@ class AccountService:
return False return False
@staticmethod @staticmethod
@redis_fallback(default_return=None)
def reset_login_error_rate_limit(email: str): def reset_login_error_rate_limit(email: str):
key = f"login_error_rate_limit:{email}" key = f"login_error_rate_limit:{email}"
redis_client.delete(key) redis_client.delete(key)
@staticmethod @staticmethod
@redis_fallback(default_return=None)
def add_forgot_password_error_rate_limit(email: str) -> None: def add_forgot_password_error_rate_limit(email: str) -> None:
key = f"forgot_password_error_rate_limit:{email}" key = f"forgot_password_error_rate_limit:{email}"
count = redis_client.get(key) count = redis_client.get(key)
@@ -530,6 +534,7 @@ class AccountService:
redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count) redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count)
@staticmethod @staticmethod
@redis_fallback(default_return=False)
def is_forgot_password_error_rate_limit(email: str) -> bool: def is_forgot_password_error_rate_limit(email: str) -> bool:
key = f"forgot_password_error_rate_limit:{email}" key = f"forgot_password_error_rate_limit:{email}"
count = redis_client.get(key) count = redis_client.get(key)
@@ -542,11 +547,13 @@ class AccountService:
return False return False
@staticmethod @staticmethod
@redis_fallback(default_return=None)
def reset_forgot_password_error_rate_limit(email: str): def reset_forgot_password_error_rate_limit(email: str):
key = f"forgot_password_error_rate_limit:{email}" key = f"forgot_password_error_rate_limit:{email}"
redis_client.delete(key) redis_client.delete(key)
@staticmethod @staticmethod
@redis_fallback(default_return=False)
def is_email_send_ip_limit(ip_address: str): def is_email_send_ip_limit(ip_address: str):
minute_key = f"email_send_ip_limit_minute:{ip_address}" minute_key = f"email_send_ip_limit_minute:{ip_address}"
freeze_key = f"email_send_ip_limit_freeze:{ip_address}" freeze_key = f"email_send_ip_limit_freeze:{ip_address}"

View File

@@ -0,0 +1,53 @@
from redis import RedisError
from extensions.ext_redis import redis_fallback
def test_redis_fallback_success():
@redis_fallback(default_return=None)
def test_func():
return "success"
assert test_func() == "success"
def test_redis_fallback_error():
@redis_fallback(default_return="fallback")
def test_func():
raise RedisError("Redis error")
assert test_func() == "fallback"
def test_redis_fallback_none_default():
@redis_fallback()
def test_func():
raise RedisError("Redis error")
assert test_func() is None
def test_redis_fallback_with_args():
@redis_fallback(default_return=0)
def test_func(x, y):
raise RedisError("Redis error")
assert test_func(1, 2) == 0
def test_redis_fallback_with_kwargs():
@redis_fallback(default_return={})
def test_func(x=None, y=None):
raise RedisError("Redis error")
assert test_func(x=1, y=2) == {}
def test_redis_fallback_preserves_function_metadata():
@redis_fallback(default_return=None)
def test_func():
"""Test function docstring"""
pass
assert test_func.__name__ == "test_func"
assert test_func.__doc__ == "Test function docstring"