test: add comprehensive test suite for rate limiting module (#23765)
This commit is contained in:
124
api/tests/unit_tests/core/app/features/rate_limiting/conftest.py
Normal file
124
api/tests/unit_tests/core/app/features/rate_limiting/conftest.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimit
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis():
|
||||
"""Mock Redis client with realistic behavior for rate limiting tests."""
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Redis data storage for simulation
|
||||
mock_data = {}
|
||||
mock_hashes = {}
|
||||
mock_expiry = {}
|
||||
|
||||
def mock_setex(key, ttl, value):
|
||||
mock_data[key] = str(value)
|
||||
mock_expiry[key] = time.time() + ttl.total_seconds() if hasattr(ttl, "total_seconds") else time.time() + ttl
|
||||
return True
|
||||
|
||||
def mock_get(key):
|
||||
if key in mock_data and (key not in mock_expiry or time.time() < mock_expiry[key]):
|
||||
return mock_data[key].encode("utf-8")
|
||||
return None
|
||||
|
||||
def mock_exists(key):
|
||||
return key in mock_data or key in mock_hashes
|
||||
|
||||
def mock_expire(key, ttl):
|
||||
if key in mock_data or key in mock_hashes:
|
||||
mock_expiry[key] = time.time() + ttl.total_seconds() if hasattr(ttl, "total_seconds") else time.time() + ttl
|
||||
return True
|
||||
|
||||
def mock_hset(key, field, value):
|
||||
if key not in mock_hashes:
|
||||
mock_hashes[key] = {}
|
||||
mock_hashes[key][field] = str(value).encode("utf-8")
|
||||
return True
|
||||
|
||||
def mock_hgetall(key):
|
||||
return mock_hashes.get(key, {})
|
||||
|
||||
def mock_hdel(key, *fields):
|
||||
if key in mock_hashes:
|
||||
count = 0
|
||||
for field in fields:
|
||||
if field in mock_hashes[key]:
|
||||
del mock_hashes[key][field]
|
||||
count += 1
|
||||
return count
|
||||
return 0
|
||||
|
||||
def mock_hlen(key):
|
||||
return len(mock_hashes.get(key, {}))
|
||||
|
||||
# Configure mock methods
|
||||
mock_client.setex = mock_setex
|
||||
mock_client.get = mock_get
|
||||
mock_client.exists = mock_exists
|
||||
mock_client.expire = mock_expire
|
||||
mock_client.hset = mock_hset
|
||||
mock_client.hgetall = mock_hgetall
|
||||
mock_client.hdel = mock_hdel
|
||||
mock_client.hlen = mock_hlen
|
||||
|
||||
# Store references for test verification
|
||||
mock_client._mock_data = mock_data
|
||||
mock_client._mock_hashes = mock_hashes
|
||||
mock_client._mock_expiry = mock_expiry
|
||||
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_time():
|
||||
"""Mock time.time() for deterministic tests."""
|
||||
mock_time_val = 1000.0
|
||||
|
||||
def increment_time(seconds=1):
|
||||
nonlocal mock_time_val
|
||||
mock_time_val += seconds
|
||||
return mock_time_val
|
||||
|
||||
with patch("time.time", return_value=mock_time_val) as mock:
|
||||
mock.increment = increment_time
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_generator():
|
||||
"""Sample generator for testing RateLimitGenerator."""
|
||||
|
||||
def _create_generator(items=None, raise_error=False):
|
||||
items = items or ["item1", "item2", "item3"]
|
||||
for item in items:
|
||||
if raise_error and item == "item2":
|
||||
raise ValueError("Test error")
|
||||
yield item
|
||||
|
||||
return _create_generator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_mapping():
|
||||
"""Sample mapping for testing RateLimitGenerator."""
|
||||
return {"key1": "value1", "key2": "value2"}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_rate_limit_instances():
|
||||
"""Clear RateLimit singleton instances between tests."""
|
||||
RateLimit._instance_dict.clear()
|
||||
yield
|
||||
RateLimit._instance_dict.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def redis_patch():
|
||||
"""Patch redis_client globally for rate limit tests."""
|
||||
with patch("core.app.features.rate_limiting.rate_limit.redis_client") as mock:
|
||||
yield mock
|
@@ -0,0 +1,569 @@
|
||||
import threading
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimit
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
|
||||
|
||||
class TestRateLimit:
|
||||
"""Core rate limiting functionality tests."""
|
||||
|
||||
def test_should_return_same_instance_for_same_client_id(self, redis_patch):
|
||||
"""Test singleton behavior for same client ID."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit1 = RateLimit("client1", 5)
|
||||
rate_limit2 = RateLimit("client1", 10) # Second instance with different limit
|
||||
|
||||
assert rate_limit1 is rate_limit2
|
||||
# Current implementation: last constructor call overwrites max_active_requests
|
||||
# This reflects the actual behavior where __init__ always sets max_active_requests
|
||||
assert rate_limit1.max_active_requests == 10
|
||||
|
||||
def test_should_create_different_instances_for_different_client_ids(self, redis_patch):
|
||||
"""Test different instances for different client IDs."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit1 = RateLimit("client1", 5)
|
||||
rate_limit2 = RateLimit("client2", 10)
|
||||
|
||||
assert rate_limit1 is not rate_limit2
|
||||
assert rate_limit1.client_id == "client1"
|
||||
assert rate_limit2.client_id == "client2"
|
||||
|
||||
def test_should_initialize_with_valid_parameters(self, redis_patch):
|
||||
"""Test normal initialization."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
|
||||
assert rate_limit.client_id == "test_client"
|
||||
assert rate_limit.max_active_requests == 5
|
||||
assert hasattr(rate_limit, "initialized")
|
||||
redis_patch.setex.assert_called_once()
|
||||
|
||||
def test_should_skip_initialization_if_disabled(self):
|
||||
"""Test no initialization when rate limiting is disabled."""
|
||||
rate_limit = RateLimit("test_client", 0)
|
||||
|
||||
assert rate_limit.disabled()
|
||||
assert not hasattr(rate_limit, "initialized")
|
||||
|
||||
def test_should_skip_reinitialization_of_existing_instance(self, redis_patch):
|
||||
"""Test that existing instance doesn't reinitialize."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
RateLimit("client1", 5)
|
||||
redis_patch.reset_mock()
|
||||
|
||||
RateLimit("client1", 10)
|
||||
|
||||
redis_patch.setex.assert_not_called()
|
||||
|
||||
def test_should_be_disabled_when_max_requests_is_zero_or_negative(self):
|
||||
"""Test disabled state for zero or negative limits."""
|
||||
rate_limit_zero = RateLimit("client1", 0)
|
||||
rate_limit_negative = RateLimit("client2", -5)
|
||||
|
||||
assert rate_limit_zero.disabled()
|
||||
assert rate_limit_negative.disabled()
|
||||
|
||||
def test_should_set_redis_keys_on_first_flush(self, redis_patch):
|
||||
"""Test Redis keys are set correctly on initial flush."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
|
||||
expected_max_key = "dify:rate_limit:test_client:max_active_requests"
|
||||
redis_patch.setex.assert_called_with(expected_max_key, timedelta(days=1), 5)
|
||||
|
||||
def test_should_sync_max_requests_from_redis_on_subsequent_flush(self, redis_patch):
|
||||
"""Test max requests syncs from Redis when key exists."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": True,
|
||||
"get.return_value": b"10",
|
||||
"expire.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
rate_limit.flush_cache()
|
||||
|
||||
assert rate_limit.max_active_requests == 10
|
||||
|
||||
@patch("time.time")
|
||||
def test_should_clean_timeout_requests_from_active_list(self, mock_time, redis_patch):
|
||||
"""Test cleanup of timed-out requests."""
|
||||
current_time = 1000.0
|
||||
mock_time.return_value = current_time
|
||||
|
||||
# Setup mock Redis with timed-out requests
|
||||
timeout_requests = {
|
||||
b"req1": str(current_time - 700).encode(), # 700 seconds ago (timeout)
|
||||
b"req2": str(current_time - 100).encode(), # 100 seconds ago (active)
|
||||
}
|
||||
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": True,
|
||||
"get.return_value": b"5",
|
||||
"expire.return_value": True,
|
||||
"hgetall.return_value": timeout_requests,
|
||||
"hdel.return_value": 1,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
redis_patch.reset_mock() # Reset to avoid counting initialization calls
|
||||
rate_limit.flush_cache()
|
||||
|
||||
# Verify timeout request was cleaned up
|
||||
redis_patch.hdel.assert_called_once()
|
||||
call_args = redis_patch.hdel.call_args[0]
|
||||
assert call_args[0] == "dify:rate_limit:test_client:active_requests"
|
||||
assert b"req1" in call_args # Timeout request should be removed
|
||||
assert b"req2" not in call_args # Active request should remain
|
||||
|
||||
|
||||
class TestRateLimitEnterExit:
|
||||
"""Rate limiting enter/exit logic tests."""
|
||||
|
||||
def test_should_allow_request_within_limit(self, redis_patch):
|
||||
"""Test allowing requests within the rate limit."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hlen.return_value": 2,
|
||||
"hset.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
request_id = rate_limit.enter()
|
||||
|
||||
assert request_id != RateLimit._UNLIMITED_REQUEST_ID
|
||||
redis_patch.hset.assert_called_once()
|
||||
|
||||
def test_should_generate_request_id_if_not_provided(self, redis_patch):
|
||||
"""Test auto-generation of request ID."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hlen.return_value": 0,
|
||||
"hset.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
request_id = rate_limit.enter()
|
||||
|
||||
assert len(request_id) == 36 # UUID format
|
||||
|
||||
def test_should_use_provided_request_id(self, redis_patch):
|
||||
"""Test using provided request ID."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hlen.return_value": 0,
|
||||
"hset.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
custom_id = "custom_request_123"
|
||||
request_id = rate_limit.enter(custom_id)
|
||||
|
||||
assert request_id == custom_id
|
||||
|
||||
def test_should_remove_request_on_exit(self, redis_patch):
|
||||
"""Test request removal on exit."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"hdel.return_value": 1,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
rate_limit.exit("test_request_id")
|
||||
|
||||
redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", "test_request_id")
|
||||
|
||||
def test_should_raise_quota_exceeded_when_at_limit(self, redis_patch):
|
||||
"""Test quota exceeded error when at limit."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hlen.return_value": 5, # At limit
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
|
||||
with pytest.raises(AppInvokeQuotaExceededError) as exc_info:
|
||||
rate_limit.enter()
|
||||
|
||||
assert "Too many requests" in str(exc_info.value)
|
||||
assert "test_client" in str(exc_info.value)
|
||||
|
||||
def test_should_allow_request_after_previous_exit(self, redis_patch):
|
||||
"""Test allowing new request after previous exit."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hlen.return_value": 4, # Under limit after exit
|
||||
"hset.return_value": True,
|
||||
"hdel.return_value": 1,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
|
||||
request_id = rate_limit.enter()
|
||||
rate_limit.exit(request_id)
|
||||
|
||||
new_request_id = rate_limit.enter()
|
||||
assert new_request_id is not None
|
||||
|
||||
@patch("time.time")
|
||||
def test_should_flush_cache_when_interval_exceeded(self, mock_time, redis_patch):
|
||||
"""Test cache flush when time interval exceeded."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hlen.return_value": 0,
|
||||
}
|
||||
)
|
||||
|
||||
mock_time.return_value = 1000.0
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
|
||||
# Advance time beyond flush interval
|
||||
mock_time.return_value = 1400.0 # 400 seconds later
|
||||
redis_patch.reset_mock()
|
||||
|
||||
rate_limit.enter()
|
||||
|
||||
# Should have called setex again due to cache flush
|
||||
redis_patch.setex.assert_called()
|
||||
|
||||
def test_should_return_unlimited_id_when_disabled(self):
|
||||
"""Test unlimited ID return when rate limiting disabled."""
|
||||
rate_limit = RateLimit("test_client", 0)
|
||||
request_id = rate_limit.enter()
|
||||
|
||||
assert request_id == RateLimit._UNLIMITED_REQUEST_ID
|
||||
|
||||
def test_should_ignore_exit_for_unlimited_requests(self, redis_patch):
|
||||
"""Test ignoring exit for unlimited requests."""
|
||||
rate_limit = RateLimit("test_client", 0)
|
||||
rate_limit.exit(RateLimit._UNLIMITED_REQUEST_ID)
|
||||
|
||||
redis_patch.hdel.assert_not_called()
|
||||
|
||||
|
||||
class TestRateLimitGenerator:
|
||||
"""Rate limit generator wrapper tests."""
|
||||
|
||||
def test_should_wrap_generator_and_iterate_normally(self, redis_patch, sample_generator):
|
||||
"""Test normal generator iteration with rate limit wrapper."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hdel.return_value": 1,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
generator = sample_generator()
|
||||
request_id = "test_request"
|
||||
|
||||
wrapped_gen = rate_limit.generate(generator, request_id)
|
||||
result = list(wrapped_gen)
|
||||
|
||||
assert result == ["item1", "item2", "item3"]
|
||||
redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", request_id)
|
||||
|
||||
def test_should_handle_mapping_input_directly(self, sample_mapping):
|
||||
"""Test direct return of mapping input."""
|
||||
rate_limit = RateLimit("test_client", 0) # Disabled
|
||||
result = rate_limit.generate(sample_mapping, "test_request")
|
||||
|
||||
assert result is sample_mapping
|
||||
|
||||
def test_should_cleanup_on_exception_during_iteration(self, redis_patch, sample_generator):
|
||||
"""Test cleanup when exception occurs during iteration."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hdel.return_value": 1,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
generator = sample_generator(raise_error=True)
|
||||
request_id = "test_request"
|
||||
|
||||
wrapped_gen = rate_limit.generate(generator, request_id)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
list(wrapped_gen)
|
||||
|
||||
redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", request_id)
|
||||
|
||||
def test_should_cleanup_on_explicit_close(self, redis_patch, sample_generator):
|
||||
"""Test cleanup on explicit generator close."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hdel.return_value": 1,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
generator = sample_generator()
|
||||
request_id = "test_request"
|
||||
|
||||
wrapped_gen = rate_limit.generate(generator, request_id)
|
||||
wrapped_gen.close()
|
||||
|
||||
redis_patch.hdel.assert_called_once()
|
||||
|
||||
def test_should_handle_generator_without_close_method(self, redis_patch):
|
||||
"""Test handling generator without close method."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hdel.return_value": 1,
|
||||
}
|
||||
)
|
||||
|
||||
# Create a generator-like object without close method
|
||||
class SimpleGenerator:
|
||||
def __init__(self):
|
||||
self.items = ["test"]
|
||||
self.index = 0
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.index >= len(self.items):
|
||||
raise StopIteration
|
||||
item = self.items[self.index]
|
||||
self.index += 1
|
||||
return item
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
generator = SimpleGenerator()
|
||||
|
||||
wrapped_gen = rate_limit.generate(generator, "test_request")
|
||||
wrapped_gen.close() # Should not raise error
|
||||
|
||||
redis_patch.hdel.assert_called_once()
|
||||
|
||||
def test_should_prevent_iteration_after_close(self, redis_patch, sample_generator):
|
||||
"""Test StopIteration after generator is closed."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hdel.return_value": 1,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
generator = sample_generator()
|
||||
|
||||
wrapped_gen = rate_limit.generate(generator, "test_request")
|
||||
wrapped_gen.close()
|
||||
|
||||
with pytest.raises(StopIteration):
|
||||
next(wrapped_gen)
|
||||
|
||||
|
||||
class TestRateLimitConcurrency:
|
||||
"""Concurrent access safety tests."""
|
||||
|
||||
def test_should_handle_concurrent_instance_creation(self, redis_patch):
|
||||
"""Test thread-safe singleton instance creation."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
instances = []
|
||||
errors = []
|
||||
|
||||
def create_instance():
|
||||
try:
|
||||
instance = RateLimit("concurrent_client", 5)
|
||||
instances.append(instance)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=create_instance) for _ in range(10)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert len(errors) == 0
|
||||
assert len({id(inst) for inst in instances}) == 1 # All same instance
|
||||
|
||||
def test_should_handle_concurrent_enter_requests(self, redis_patch):
|
||||
"""Test concurrent enter requests handling."""
|
||||
# Setup mock to simulate realistic Redis behavior
|
||||
request_count = 0
|
||||
|
||||
def mock_hlen(key):
|
||||
nonlocal request_count
|
||||
return request_count
|
||||
|
||||
def mock_hset(key, field, value):
|
||||
nonlocal request_count
|
||||
request_count += 1
|
||||
return True
|
||||
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hlen.side_effect": mock_hlen,
|
||||
"hset.side_effect": mock_hset,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("concurrent_client", 3)
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
def try_enter():
|
||||
try:
|
||||
request_id = rate_limit.enter()
|
||||
results.append(request_id)
|
||||
except AppInvokeQuotaExceededError as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=try_enter) for _ in range(5)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Should have some successful requests and some quota exceeded
|
||||
assert len(results) + len(errors) == 5
|
||||
assert len(errors) > 0 # Some should be rejected
|
||||
|
||||
@patch("time.time")
|
||||
def test_should_maintain_accurate_count_under_load(self, mock_time, redis_patch):
|
||||
"""Test accurate count maintenance under concurrent load."""
|
||||
mock_time.return_value = 1000.0
|
||||
|
||||
# Use real mock_redis fixture for better simulation
|
||||
mock_client = self._create_mock_redis()
|
||||
redis_patch.configure_mock(**mock_client)
|
||||
|
||||
rate_limit = RateLimit("load_test_client", 10)
|
||||
active_requests = []
|
||||
|
||||
def enter_and_exit():
|
||||
try:
|
||||
request_id = rate_limit.enter()
|
||||
active_requests.append(request_id)
|
||||
time.sleep(0.01) # Simulate some work
|
||||
rate_limit.exit(request_id)
|
||||
active_requests.remove(request_id)
|
||||
except AppInvokeQuotaExceededError:
|
||||
pass # Expected under load
|
||||
|
||||
threads = [threading.Thread(target=enter_and_exit) for _ in range(20)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# All requests should have been cleaned up
|
||||
assert len(active_requests) == 0
|
||||
|
||||
def _create_mock_redis(self):
|
||||
"""Create a thread-safe mock Redis for concurrency tests."""
|
||||
import threading
|
||||
|
||||
lock = threading.Lock()
|
||||
data = {}
|
||||
hashes = {}
|
||||
|
||||
def mock_hlen(key):
|
||||
with lock:
|
||||
return len(hashes.get(key, {}))
|
||||
|
||||
def mock_hset(key, field, value):
|
||||
with lock:
|
||||
if key not in hashes:
|
||||
hashes[key] = {}
|
||||
hashes[key][field] = str(value).encode("utf-8")
|
||||
return True
|
||||
|
||||
def mock_hdel(key, *fields):
|
||||
with lock:
|
||||
if key in hashes:
|
||||
count = 0
|
||||
for field in fields:
|
||||
if field in hashes[key]:
|
||||
del hashes[key][field]
|
||||
count += 1
|
||||
return count
|
||||
return 0
|
||||
|
||||
return {
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hlen.side_effect": mock_hlen,
|
||||
"hset.side_effect": mock_hset,
|
||||
"hdel.side_effect": mock_hdel,
|
||||
}
|
Reference in New Issue
Block a user