Co-authored-by: tech <cto@sb> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
import functools
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Union
|
||||
|
||||
import redis
|
||||
from redis import RedisError
|
||||
from redis.cache import CacheConfig
|
||||
from redis.cluster import ClusterNode, RedisCluster
|
||||
from redis.connection import Connection, SSLConnection
|
||||
@@ -9,6 +13,8 @@ from redis.sentinel import Sentinel
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisClientWrapper:
|
||||
"""
|
||||
@@ -115,3 +121,25 @@ def init_app(app: DifyApp):
|
||||
redis_client.initialize(redis.Redis(connection_pool=pool))
|
||||
|
||||
app.extensions["redis"] = redis_client
|
||||
|
||||
|
||||
def redis_fallback(default_return: Any = None):
|
||||
"""
|
||||
decorator to handle Redis operation exceptions and return a default value when Redis is unavailable.
|
||||
|
||||
Args:
|
||||
default_return: The value to return when a Redis operation fails. Defaults to None.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except RedisError as e:
|
||||
logger.warning(f"Redis operation failed in {func.__name__}: {str(e)}", exc_info=True)
|
||||
return default_return
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
@@ -16,7 +16,7 @@ from configs import dify_config
|
||||
from constants.languages import language_timezone_mapping, languages
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_redis import redis_client, redis_fallback
|
||||
from libs.helper import RateLimiter, TokenManager
|
||||
from libs.passport import PassportService
|
||||
from libs.password import compare_password, hash_password, valid_password
|
||||
@@ -495,6 +495,7 @@ class AccountService:
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_login_error_rate_limit(email: str) -> None:
|
||||
key = f"login_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
@@ -504,6 +505,7 @@ class AccountService:
|
||||
redis_client.setex(key, dify_config.LOGIN_LOCKOUT_DURATION, count)
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=False)
|
||||
def is_login_error_rate_limit(email: str) -> bool:
|
||||
key = f"login_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
@@ -516,11 +518,13 @@ class AccountService:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def reset_login_error_rate_limit(email: str):
|
||||
key = f"login_error_rate_limit:{email}"
|
||||
redis_client.delete(key)
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_forgot_password_error_rate_limit(email: str) -> None:
|
||||
key = f"forgot_password_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
@@ -530,6 +534,7 @@ class AccountService:
|
||||
redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count)
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=False)
|
||||
def is_forgot_password_error_rate_limit(email: str) -> bool:
|
||||
key = f"forgot_password_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
@@ -542,11 +547,13 @@ class AccountService:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def reset_forgot_password_error_rate_limit(email: str):
|
||||
key = f"forgot_password_error_rate_limit:{email}"
|
||||
redis_client.delete(key)
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=False)
|
||||
def is_email_send_ip_limit(ip_address: str):
|
||||
minute_key = f"email_send_ip_limit_minute:{ip_address}"
|
||||
freeze_key = f"email_send_ip_limit_freeze:{ip_address}"
|
||||
|
53
api/tests/unit_tests/extensions/test_redis.py
Normal file
53
api/tests/unit_tests/extensions/test_redis.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from redis import RedisError
|
||||
|
||||
from extensions.ext_redis import redis_fallback
|
||||
|
||||
|
||||
def test_redis_fallback_success():
|
||||
@redis_fallback(default_return=None)
|
||||
def test_func():
|
||||
return "success"
|
||||
|
||||
assert test_func() == "success"
|
||||
|
||||
|
||||
def test_redis_fallback_error():
|
||||
@redis_fallback(default_return="fallback")
|
||||
def test_func():
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func() == "fallback"
|
||||
|
||||
|
||||
def test_redis_fallback_none_default():
|
||||
@redis_fallback()
|
||||
def test_func():
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func() is None
|
||||
|
||||
|
||||
def test_redis_fallback_with_args():
|
||||
@redis_fallback(default_return=0)
|
||||
def test_func(x, y):
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func(1, 2) == 0
|
||||
|
||||
|
||||
def test_redis_fallback_with_kwargs():
|
||||
@redis_fallback(default_return={})
|
||||
def test_func(x=None, y=None):
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func(x=1, y=2) == {}
|
||||
|
||||
|
||||
def test_redis_fallback_preserves_function_metadata():
|
||||
@redis_fallback(default_return=None)
|
||||
def test_func():
|
||||
"""Test function docstring"""
|
||||
pass
|
||||
|
||||
assert test_func.__name__ == "test_func"
|
||||
assert test_func.__doc__ == "Test function docstring"
|
Reference in New Issue
Block a user