test: add comprehensive test suite for rate limiting module (#23765)

Jason Young
2025-08-12 10:05:30 +08:00
committed by GitHub
parent a6c5b7414d
commit b38f195a0d
2 changed files with 693 additions and 0 deletions
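For orientation, the usage pattern these tests exercise looks roughly like the sketch below. It is reconstructed from the assertions in the suite; the RateLimit implementation itself is not part of this diff, and the helper names (do_rate_limited_work, stream_chunks, handle) are placeholders.

# Sketch of the API surface exercised by the suite (illustrative, not the implementation)
rate_limit = RateLimit("my_app", 5)        # per-client singleton; at most 5 concurrent requests

request_id = rate_limit.enter()            # raises AppInvokeQuotaExceededError once the limit is reached
try:
    do_rate_limited_work()                 # placeholder for the caller's work
finally:
    rate_limit.exit(request_id)            # frees the slot tracked in Redis

# Streaming responses are wrapped so the slot is freed when iteration finishes, fails, or is closed
for chunk in rate_limit.generate(stream_chunks(), rate_limit.enter()):
    handle(chunk)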


@@ -0,0 +1,124 @@
import time
from unittest.mock import MagicMock, patch
import pytest
from core.app.features.rate_limiting.rate_limit import RateLimit
@pytest.fixture
def mock_redis():
"""Mock Redis client with realistic behavior for rate limiting tests."""
mock_client = MagicMock()
# Redis data storage for simulation
mock_data = {}
mock_hashes = {}
mock_expiry = {}
def mock_setex(key, ttl, value):
mock_data[key] = str(value)
mock_expiry[key] = time.time() + (ttl.total_seconds() if hasattr(ttl, "total_seconds") else ttl)
return True
def mock_get(key):
if key in mock_data and (key not in mock_expiry or time.time() < mock_expiry[key]):
return mock_data[key].encode("utf-8")
return None
def mock_exists(key):
return key in mock_data or key in mock_hashes
def mock_expire(key, ttl):
if key in mock_data or key in mock_hashes:
mock_expiry[key] = time.time() + (ttl.total_seconds() if hasattr(ttl, "total_seconds") else ttl)
return True
def mock_hset(key, field, value):
if key not in mock_hashes:
mock_hashes[key] = {}
mock_hashes[key][field] = str(value).encode("utf-8")
return True
def mock_hgetall(key):
return mock_hashes.get(key, {})
def mock_hdel(key, *fields):
if key in mock_hashes:
count = 0
for field in fields:
if field in mock_hashes[key]:
del mock_hashes[key][field]
count += 1
return count
return 0
def mock_hlen(key):
return len(mock_hashes.get(key, {}))
# Configure mock methods
mock_client.setex = mock_setex
mock_client.get = mock_get
mock_client.exists = mock_exists
mock_client.expire = mock_expire
mock_client.hset = mock_hset
mock_client.hgetall = mock_hgetall
mock_client.hdel = mock_hdel
mock_client.hlen = mock_hlen
# Store references for test verification
mock_client._mock_data = mock_data
mock_client._mock_hashes = mock_hashes
mock_client._mock_expiry = mock_expiry
return mock_client
@pytest.fixture
def mock_time():
"""Mock time.time() for deterministic tests."""
mock_time_val = 1000.0
def increment_time(seconds=1):
nonlocal mock_time_val
mock_time_val += seconds
return mock_time_val
with patch("time.time", return_value=mock_time_val) as mock:
mock.increment = increment_time
yield mock
@pytest.fixture
def sample_generator():
"""Sample generator for testing RateLimitGenerator."""
def _create_generator(items=None, raise_error=False):
items = items or ["item1", "item2", "item3"]
for item in items:
if raise_error and item == "item2":
raise ValueError("Test error")
yield item
return _create_generator
@pytest.fixture
def sample_mapping():
"""Sample mapping for testing RateLimitGenerator."""
return {"key1": "value1", "key2": "value2"}
@pytest.fixture(autouse=True)
def reset_rate_limit_instances():
"""Clear RateLimit singleton instances between tests."""
RateLimit._instance_dict.clear()
yield
RateLimit._instance_dict.clear()
@pytest.fixture
def redis_patch():
"""Patch redis_client globally for rate limit tests."""
with patch("core.app.features.rate_limiting.rate_limit.redis_client") as mock:
yield mock
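Note that the mock_redis fixture above is not referenced by the tests in the second file, which configure redis_patch per test instead. Below is a minimal sketch of how the two could be combined; the test name and wiring are hypothetical, the asserted key name is taken from the suite, and any Redis command not covered by the realistic mock would fall back to MagicMock defaults.

def test_enter_exit_roundtrip(redis_patch, mock_redis):
    # Route the patched redis_client through the realistic in-memory mock (hypothetical wiring)
    for name in ("setex", "get", "exists", "expire", "hset", "hgetall", "hdel", "hlen"):
        setattr(redis_patch, name, getattr(mock_redis, name))
    rate_limit = RateLimit("demo_client", 2)
    request_id = rate_limit.enter()
    rate_limit.exit(request_id)
    # After exit, no active requests should remain for this client
    assert mock_redis.hlen("dify:rate_limit:demo_client:active_requests") == 0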


@@ -0,0 +1,569 @@
import threading
import time
from datetime import timedelta
from unittest.mock import patch
import pytest
from core.app.features.rate_limiting.rate_limit import RateLimit
from core.errors.error import AppInvokeQuotaExceededError
class TestRateLimit:
"""Core rate limiting functionality tests."""
def test_should_return_same_instance_for_same_client_id(self, redis_patch):
"""Test singleton behavior for same client ID."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
}
)
rate_limit1 = RateLimit("client1", 5)
rate_limit2 = RateLimit("client1", 10) # Second instance with different limit
assert rate_limit1 is rate_limit2
# Current behavior: __init__ always sets max_active_requests, so the last constructor call wins
assert rate_limit1.max_active_requests == 10
def test_should_create_different_instances_for_different_client_ids(self, redis_patch):
"""Test different instances for different client IDs."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
}
)
rate_limit1 = RateLimit("client1", 5)
rate_limit2 = RateLimit("client2", 10)
assert rate_limit1 is not rate_limit2
assert rate_limit1.client_id == "client1"
assert rate_limit2.client_id == "client2"
def test_should_initialize_with_valid_parameters(self, redis_patch):
"""Test normal initialization."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
}
)
rate_limit = RateLimit("test_client", 5)
assert rate_limit.client_id == "test_client"
assert rate_limit.max_active_requests == 5
assert hasattr(rate_limit, "initialized")
redis_patch.setex.assert_called_once()
def test_should_skip_initialization_if_disabled(self):
"""Test no initialization when rate limiting is disabled."""
rate_limit = RateLimit("test_client", 0)
assert rate_limit.disabled()
assert not hasattr(rate_limit, "initialized")
def test_should_skip_reinitialization_of_existing_instance(self, redis_patch):
"""Test that existing instance doesn't reinitialize."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
}
)
RateLimit("client1", 5)
redis_patch.reset_mock()
RateLimit("client1", 10)
redis_patch.setex.assert_not_called()
def test_should_be_disabled_when_max_requests_is_zero_or_negative(self):
"""Test disabled state for zero or negative limits."""
rate_limit_zero = RateLimit("client1", 0)
rate_limit_negative = RateLimit("client2", -5)
assert rate_limit_zero.disabled()
assert rate_limit_negative.disabled()
def test_should_set_redis_keys_on_first_flush(self, redis_patch):
"""Test Redis keys are set correctly on initial flush."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
}
)
rate_limit = RateLimit("test_client", 5)
expected_max_key = "dify:rate_limit:test_client:max_active_requests"
redis_patch.setex.assert_called_with(expected_max_key, timedelta(days=1), 5)
def test_should_sync_max_requests_from_redis_on_subsequent_flush(self, redis_patch):
"""Test max requests syncs from Redis when key exists."""
redis_patch.configure_mock(
**{
"exists.return_value": True,
"get.return_value": b"10",
"expire.return_value": True,
}
)
rate_limit = RateLimit("test_client", 5)
rate_limit.flush_cache()
assert rate_limit.max_active_requests == 10
@patch("time.time")
def test_should_clean_timeout_requests_from_active_list(self, mock_time, redis_patch):
"""Test cleanup of timed-out requests."""
current_time = 1000.0
mock_time.return_value = current_time
# Setup mock Redis with timed-out requests
timeout_requests = {
b"req1": str(current_time - 700).encode(), # 700 seconds ago (timeout)
b"req2": str(current_time - 100).encode(), # 100 seconds ago (active)
}
redis_patch.configure_mock(
**{
"exists.return_value": True,
"get.return_value": b"5",
"expire.return_value": True,
"hgetall.return_value": timeout_requests,
"hdel.return_value": 1,
}
)
rate_limit = RateLimit("test_client", 5)
redis_patch.reset_mock() # Reset to avoid counting initialization calls
rate_limit.flush_cache()
# Verify timeout request was cleaned up
redis_patch.hdel.assert_called_once()
call_args = redis_patch.hdel.call_args[0]
assert call_args[0] == "dify:rate_limit:test_client:active_requests"
assert b"req1" in call_args # Timeout request should be removed
assert b"req2" not in call_args # Active request should remain
class TestRateLimitEnterExit:
"""Rate limiting enter/exit logic tests."""
def test_should_allow_request_within_limit(self, redis_patch):
"""Test allowing requests within the rate limit."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hlen.return_value": 2,
"hset.return_value": True,
}
)
rate_limit = RateLimit("test_client", 5)
request_id = rate_limit.enter()
assert request_id != RateLimit._UNLIMITED_REQUEST_ID
redis_patch.hset.assert_called_once()
def test_should_generate_request_id_if_not_provided(self, redis_patch):
"""Test auto-generation of request ID."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hlen.return_value": 0,
"hset.return_value": True,
}
)
rate_limit = RateLimit("test_client", 5)
request_id = rate_limit.enter()
assert len(request_id) == 36 # UUID format
def test_should_use_provided_request_id(self, redis_patch):
"""Test using provided request ID."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hlen.return_value": 0,
"hset.return_value": True,
}
)
rate_limit = RateLimit("test_client", 5)
custom_id = "custom_request_123"
request_id = rate_limit.enter(custom_id)
assert request_id == custom_id
def test_should_remove_request_on_exit(self, redis_patch):
"""Test request removal on exit."""
redis_patch.configure_mock(
**{
"hdel.return_value": 1,
}
)
rate_limit = RateLimit("test_client", 5)
rate_limit.exit("test_request_id")
redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", "test_request_id")
def test_should_raise_quota_exceeded_when_at_limit(self, redis_patch):
"""Test quota exceeded error when at limit."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hlen.return_value": 5, # At limit
}
)
rate_limit = RateLimit("test_client", 5)
with pytest.raises(AppInvokeQuotaExceededError) as exc_info:
rate_limit.enter()
assert "Too many requests" in str(exc_info.value)
assert "test_client" in str(exc_info.value)
def test_should_allow_request_after_previous_exit(self, redis_patch):
"""Test allowing new request after previous exit."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hlen.return_value": 4, # Under limit after exit
"hset.return_value": True,
"hdel.return_value": 1,
}
)
rate_limit = RateLimit("test_client", 5)
request_id = rate_limit.enter()
rate_limit.exit(request_id)
new_request_id = rate_limit.enter()
assert new_request_id is not None
@patch("time.time")
def test_should_flush_cache_when_interval_exceeded(self, mock_time, redis_patch):
"""Test cache flush when time interval exceeded."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hlen.return_value": 0,
}
)
mock_time.return_value = 1000.0
rate_limit = RateLimit("test_client", 5)
# Advance time beyond flush interval
mock_time.return_value = 1400.0 # 400 seconds later
redis_patch.reset_mock()
rate_limit.enter()
# Should have called setex again due to cache flush
redis_patch.setex.assert_called()
def test_should_return_unlimited_id_when_disabled(self):
"""Test unlimited ID return when rate limiting disabled."""
rate_limit = RateLimit("test_client", 0)
request_id = rate_limit.enter()
assert request_id == RateLimit._UNLIMITED_REQUEST_ID
def test_should_ignore_exit_for_unlimited_requests(self, redis_patch):
"""Test ignoring exit for unlimited requests."""
rate_limit = RateLimit("test_client", 0)
rate_limit.exit(RateLimit._UNLIMITED_REQUEST_ID)
redis_patch.hdel.assert_not_called()
class TestRateLimitGenerator:
"""Rate limit generator wrapper tests."""
def test_should_wrap_generator_and_iterate_normally(self, redis_patch, sample_generator):
"""Test normal generator iteration with rate limit wrapper."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hdel.return_value": 1,
}
)
rate_limit = RateLimit("test_client", 5)
generator = sample_generator()
request_id = "test_request"
wrapped_gen = rate_limit.generate(generator, request_id)
result = list(wrapped_gen)
assert result == ["item1", "item2", "item3"]
redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", request_id)
def test_should_handle_mapping_input_directly(self, sample_mapping):
"""Test direct return of mapping input."""
rate_limit = RateLimit("test_client", 0) # Disabled
result = rate_limit.generate(sample_mapping, "test_request")
assert result is sample_mapping
def test_should_cleanup_on_exception_during_iteration(self, redis_patch, sample_generator):
"""Test cleanup when exception occurs during iteration."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hdel.return_value": 1,
}
)
rate_limit = RateLimit("test_client", 5)
generator = sample_generator(raise_error=True)
request_id = "test_request"
wrapped_gen = rate_limit.generate(generator, request_id)
with pytest.raises(ValueError):
list(wrapped_gen)
redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", request_id)
def test_should_cleanup_on_explicit_close(self, redis_patch, sample_generator):
"""Test cleanup on explicit generator close."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hdel.return_value": 1,
}
)
rate_limit = RateLimit("test_client", 5)
generator = sample_generator()
request_id = "test_request"
wrapped_gen = rate_limit.generate(generator, request_id)
wrapped_gen.close()
redis_patch.hdel.assert_called_once()
def test_should_handle_generator_without_close_method(self, redis_patch):
"""Test handling generator without close method."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hdel.return_value": 1,
}
)
# Create a generator-like object without close method
class SimpleGenerator:
def __init__(self):
self.items = ["test"]
self.index = 0
def __iter__(self):
return self
def __next__(self):
if self.index >= len(self.items):
raise StopIteration
item = self.items[self.index]
self.index += 1
return item
rate_limit = RateLimit("test_client", 5)
generator = SimpleGenerator()
wrapped_gen = rate_limit.generate(generator, "test_request")
wrapped_gen.close() # Should not raise error
redis_patch.hdel.assert_called_once()
def test_should_prevent_iteration_after_close(self, redis_patch, sample_generator):
"""Test StopIteration after generator is closed."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hdel.return_value": 1,
}
)
rate_limit = RateLimit("test_client", 5)
generator = sample_generator()
wrapped_gen = rate_limit.generate(generator, "test_request")
wrapped_gen.close()
with pytest.raises(StopIteration):
next(wrapped_gen)
class TestRateLimitConcurrency:
"""Concurrent access safety tests."""
def test_should_handle_concurrent_instance_creation(self, redis_patch):
"""Test thread-safe singleton instance creation."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
}
)
instances = []
errors = []
def create_instance():
try:
instance = RateLimit("concurrent_client", 5)
instances.append(instance)
except Exception as e:
errors.append(e)
threads = [threading.Thread(target=create_instance) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(errors) == 0
assert len({id(inst) for inst in instances}) == 1 # All same instance
def test_should_handle_concurrent_enter_requests(self, redis_patch):
"""Test concurrent enter requests handling."""
# Set up the mock to simulate realistic Redis behavior
request_count = 0
def mock_hlen(key):
nonlocal request_count
return request_count
def mock_hset(key, field, value):
nonlocal request_count
request_count += 1
return True
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hlen.side_effect": mock_hlen,
"hset.side_effect": mock_hset,
}
)
rate_limit = RateLimit("concurrent_client", 3)
results = []
errors = []
def try_enter():
try:
request_id = rate_limit.enter()
results.append(request_id)
except AppInvokeQuotaExceededError as e:
errors.append(e)
threads = [threading.Thread(target=try_enter) for _ in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
# Should have some successful requests and some quota exceeded
assert len(results) + len(errors) == 5
assert len(errors) > 0 # Some should be rejected
@patch("time.time")
def test_should_maintain_accurate_count_under_load(self, mock_time, redis_patch):
"""Test accurate count maintenance under concurrent load."""
mock_time.return_value = 1000.0
# Use the thread-safe in-memory mock helper for a more realistic simulation
mock_client = self._create_mock_redis()
redis_patch.configure_mock(**mock_client)
rate_limit = RateLimit("load_test_client", 10)
active_requests = []
def enter_and_exit():
try:
request_id = rate_limit.enter()
active_requests.append(request_id)
time.sleep(0.01) # Simulate some work
rate_limit.exit(request_id)
active_requests.remove(request_id)
except AppInvokeQuotaExceededError:
pass # Expected under load
threads = [threading.Thread(target=enter_and_exit) for _ in range(20)]
for t in threads:
t.start()
for t in threads:
t.join()
# All requests should have been cleaned up
assert len(active_requests) == 0
def _create_mock_redis(self):
"""Create a thread-safe mock Redis for concurrency tests."""
lock = threading.Lock()
data = {}
hashes = {}
def mock_hlen(key):
with lock:
return len(hashes.get(key, {}))
def mock_hset(key, field, value):
with lock:
if key not in hashes:
hashes[key] = {}
hashes[key][field] = str(value).encode("utf-8")
return True
def mock_hdel(key, *fields):
with lock:
if key in hashes:
count = 0
for field in fields:
if field in hashes[key]:
del hashes[key][field]
count += 1
return count
return 0
return {
"exists.return_value": False,
"setex.return_value": True,
"hlen.side_effect": mock_hlen,
"hset.side_effect": mock_hset,
"hdel.side_effect": mock_hdel,
}