Co-authored-by: tech <cto@sb> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,10 @@
|
|||||||
|
import functools
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
|
from redis import RedisError
|
||||||
from redis.cache import CacheConfig
|
from redis.cache import CacheConfig
|
||||||
from redis.cluster import ClusterNode, RedisCluster
|
from redis.cluster import ClusterNode, RedisCluster
|
||||||
from redis.connection import Connection, SSLConnection
|
from redis.connection import Connection, SSLConnection
|
||||||
@@ -9,6 +13,8 @@ from redis.sentinel import Sentinel
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from dify_app import DifyApp
|
from dify_app import DifyApp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RedisClientWrapper:
|
class RedisClientWrapper:
|
||||||
"""
|
"""
|
||||||
@@ -115,3 +121,25 @@ def init_app(app: DifyApp):
|
|||||||
redis_client.initialize(redis.Redis(connection_pool=pool))
|
redis_client.initialize(redis.Redis(connection_pool=pool))
|
||||||
|
|
||||||
app.extensions["redis"] = redis_client
|
app.extensions["redis"] = redis_client
|
||||||
|
|
||||||
|
|
||||||
|
def redis_fallback(default_return: Any = None):
|
||||||
|
"""
|
||||||
|
decorator to handle Redis operation exceptions and return a default value when Redis is unavailable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
default_return: The value to return when a Redis operation fails. Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func: Callable):
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except RedisError as e:
|
||||||
|
logger.warning(f"Redis operation failed in {func.__name__}: {str(e)}", exc_info=True)
|
||||||
|
return default_return
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
@@ -16,7 +16,7 @@ from configs import dify_config
|
|||||||
from constants.languages import language_timezone_mapping, languages
|
from constants.languages import language_timezone_mapping, languages
|
||||||
from events.tenant_event import tenant_was_created
|
from events.tenant_event import tenant_was_created
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client, redis_fallback
|
||||||
from libs.helper import RateLimiter, TokenManager
|
from libs.helper import RateLimiter, TokenManager
|
||||||
from libs.passport import PassportService
|
from libs.passport import PassportService
|
||||||
from libs.password import compare_password, hash_password, valid_password
|
from libs.password import compare_password, hash_password, valid_password
|
||||||
@@ -495,6 +495,7 @@ class AccountService:
|
|||||||
return account
|
return account
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@redis_fallback(default_return=None)
|
||||||
def add_login_error_rate_limit(email: str) -> None:
|
def add_login_error_rate_limit(email: str) -> None:
|
||||||
key = f"login_error_rate_limit:{email}"
|
key = f"login_error_rate_limit:{email}"
|
||||||
count = redis_client.get(key)
|
count = redis_client.get(key)
|
||||||
@@ -504,6 +505,7 @@ class AccountService:
|
|||||||
redis_client.setex(key, dify_config.LOGIN_LOCKOUT_DURATION, count)
|
redis_client.setex(key, dify_config.LOGIN_LOCKOUT_DURATION, count)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@redis_fallback(default_return=False)
|
||||||
def is_login_error_rate_limit(email: str) -> bool:
|
def is_login_error_rate_limit(email: str) -> bool:
|
||||||
key = f"login_error_rate_limit:{email}"
|
key = f"login_error_rate_limit:{email}"
|
||||||
count = redis_client.get(key)
|
count = redis_client.get(key)
|
||||||
@@ -516,11 +518,13 @@ class AccountService:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@redis_fallback(default_return=None)
|
||||||
def reset_login_error_rate_limit(email: str):
|
def reset_login_error_rate_limit(email: str):
|
||||||
key = f"login_error_rate_limit:{email}"
|
key = f"login_error_rate_limit:{email}"
|
||||||
redis_client.delete(key)
|
redis_client.delete(key)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@redis_fallback(default_return=None)
|
||||||
def add_forgot_password_error_rate_limit(email: str) -> None:
|
def add_forgot_password_error_rate_limit(email: str) -> None:
|
||||||
key = f"forgot_password_error_rate_limit:{email}"
|
key = f"forgot_password_error_rate_limit:{email}"
|
||||||
count = redis_client.get(key)
|
count = redis_client.get(key)
|
||||||
@@ -530,6 +534,7 @@ class AccountService:
|
|||||||
redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count)
|
redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@redis_fallback(default_return=False)
|
||||||
def is_forgot_password_error_rate_limit(email: str) -> bool:
|
def is_forgot_password_error_rate_limit(email: str) -> bool:
|
||||||
key = f"forgot_password_error_rate_limit:{email}"
|
key = f"forgot_password_error_rate_limit:{email}"
|
||||||
count = redis_client.get(key)
|
count = redis_client.get(key)
|
||||||
@@ -542,11 +547,13 @@ class AccountService:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@redis_fallback(default_return=None)
|
||||||
def reset_forgot_password_error_rate_limit(email: str):
|
def reset_forgot_password_error_rate_limit(email: str):
|
||||||
key = f"forgot_password_error_rate_limit:{email}"
|
key = f"forgot_password_error_rate_limit:{email}"
|
||||||
redis_client.delete(key)
|
redis_client.delete(key)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@redis_fallback(default_return=False)
|
||||||
def is_email_send_ip_limit(ip_address: str):
|
def is_email_send_ip_limit(ip_address: str):
|
||||||
minute_key = f"email_send_ip_limit_minute:{ip_address}"
|
minute_key = f"email_send_ip_limit_minute:{ip_address}"
|
||||||
freeze_key = f"email_send_ip_limit_freeze:{ip_address}"
|
freeze_key = f"email_send_ip_limit_freeze:{ip_address}"
|
||||||
|
53
api/tests/unit_tests/extensions/test_redis.py
Normal file
53
api/tests/unit_tests/extensions/test_redis.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
from redis import RedisError
|
||||||
|
|
||||||
|
from extensions.ext_redis import redis_fallback
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_fallback_success():
|
||||||
|
@redis_fallback(default_return=None)
|
||||||
|
def test_func():
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
assert test_func() == "success"
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_fallback_error():
|
||||||
|
@redis_fallback(default_return="fallback")
|
||||||
|
def test_func():
|
||||||
|
raise RedisError("Redis error")
|
||||||
|
|
||||||
|
assert test_func() == "fallback"
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_fallback_none_default():
|
||||||
|
@redis_fallback()
|
||||||
|
def test_func():
|
||||||
|
raise RedisError("Redis error")
|
||||||
|
|
||||||
|
assert test_func() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_fallback_with_args():
|
||||||
|
@redis_fallback(default_return=0)
|
||||||
|
def test_func(x, y):
|
||||||
|
raise RedisError("Redis error")
|
||||||
|
|
||||||
|
assert test_func(1, 2) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_fallback_with_kwargs():
|
||||||
|
@redis_fallback(default_return={})
|
||||||
|
def test_func(x=None, y=None):
|
||||||
|
raise RedisError("Redis error")
|
||||||
|
|
||||||
|
assert test_func(x=1, y=2) == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_fallback_preserves_function_metadata():
|
||||||
|
@redis_fallback(default_return=None)
|
||||||
|
def test_func():
|
||||||
|
"""Test function docstring"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert test_func.__name__ == "test_func"
|
||||||
|
assert test_func.__doc__ == "Test function docstring"
|
Reference in New Issue
Block a user