fix(event_handlers): DB dead lock (#21468)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -3,8 +3,10 @@ from .clean_when_document_deleted import handle
|
|||||||
from .create_document_index import handle
|
from .create_document_index import handle
|
||||||
from .create_installed_app_when_app_created import handle
|
from .create_installed_app_when_app_created import handle
|
||||||
from .create_site_record_when_app_created import handle
|
from .create_site_record_when_app_created import handle
|
||||||
from .deduct_quota_when_message_created import handle
|
|
||||||
from .delete_tool_parameters_cache_when_sync_draft_workflow import handle
|
from .delete_tool_parameters_cache_when_sync_draft_workflow import handle
|
||||||
from .update_app_dataset_join_when_app_model_config_updated import handle
|
from .update_app_dataset_join_when_app_model_config_updated import handle
|
||||||
from .update_app_dataset_join_when_app_published_workflow_updated import handle
|
from .update_app_dataset_join_when_app_published_workflow_updated import handle
|
||||||
from .update_provider_last_used_at_when_message_created import handle
|
|
||||||
|
# Consolidated handler replaces both deduct_quota_when_message_created and
|
||||||
|
# update_provider_last_used_at_when_message_created
|
||||||
|
from .update_provider_when_message_created import handle
|
||||||
|
@@ -1,65 +0,0 @@
|
|||||||
from datetime import UTC, datetime
|
|
||||||
|
|
||||||
from configs import dify_config
|
|
||||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
|
|
||||||
from core.entities.provider_entities import QuotaUnit
|
|
||||||
from core.plugin.entities.plugin import ModelProviderID
|
|
||||||
from events.message_event import message_was_created
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.provider import Provider, ProviderType
|
|
||||||
|
|
||||||
|
|
||||||
@message_was_created.connect
|
|
||||||
def handle(sender, **kwargs):
|
|
||||||
message = sender
|
|
||||||
application_generate_entity = kwargs.get("application_generate_entity")
|
|
||||||
|
|
||||||
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
|
|
||||||
return
|
|
||||||
|
|
||||||
model_config = application_generate_entity.model_conf
|
|
||||||
provider_model_bundle = model_config.provider_model_bundle
|
|
||||||
provider_configuration = provider_model_bundle.configuration
|
|
||||||
|
|
||||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
|
||||||
return
|
|
||||||
|
|
||||||
system_configuration = provider_configuration.system_configuration
|
|
||||||
|
|
||||||
if not system_configuration.current_quota_type:
|
|
||||||
return
|
|
||||||
|
|
||||||
quota_unit = None
|
|
||||||
for quota_configuration in system_configuration.quota_configurations:
|
|
||||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
|
||||||
quota_unit = quota_configuration.quota_unit
|
|
||||||
|
|
||||||
if quota_configuration.quota_limit == -1:
|
|
||||||
return
|
|
||||||
|
|
||||||
break
|
|
||||||
|
|
||||||
used_quota = None
|
|
||||||
if quota_unit:
|
|
||||||
if quota_unit == QuotaUnit.TOKENS:
|
|
||||||
used_quota = message.message_tokens + message.answer_tokens
|
|
||||||
elif quota_unit == QuotaUnit.CREDITS:
|
|
||||||
used_quota = dify_config.get_model_credits(model_config.model)
|
|
||||||
else:
|
|
||||||
used_quota = 1
|
|
||||||
|
|
||||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
|
||||||
db.session.query(Provider).filter(
|
|
||||||
Provider.tenant_id == application_generate_entity.app_config.tenant_id,
|
|
||||||
# TODO: Use provider name with prefix after the data migration.
|
|
||||||
Provider.provider_name == ModelProviderID(model_config.provider).provider_name,
|
|
||||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
|
||||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
|
||||||
Provider.quota_limit > Provider.quota_used,
|
|
||||||
).update(
|
|
||||||
{
|
|
||||||
"quota_used": Provider.quota_used + used_quota,
|
|
||||||
"last_used": datetime.now(tz=UTC).replace(tzinfo=None),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
db.session.commit()
|
|
@@ -1,20 +0,0 @@
|
|||||||
from datetime import UTC, datetime
|
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
|
|
||||||
from events.message_event import message_was_created
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.provider import Provider
|
|
||||||
|
|
||||||
|
|
||||||
@message_was_created.connect
|
|
||||||
def handle(sender, **kwargs):
|
|
||||||
application_generate_entity = kwargs.get("application_generate_entity")
|
|
||||||
|
|
||||||
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
|
|
||||||
return
|
|
||||||
|
|
||||||
db.session.query(Provider).filter(
|
|
||||||
Provider.tenant_id == application_generate_entity.app_config.tenant_id,
|
|
||||||
Provider.provider_name == application_generate_entity.model_conf.provider,
|
|
||||||
).update({"last_used": datetime.now(UTC).replace(tzinfo=None)})
|
|
||||||
db.session.commit()
|
|
@@ -0,0 +1,233 @@
|
|||||||
|
import logging
|
||||||
|
import time as time_module
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import update
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
|
||||||
|
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
|
||||||
|
from core.plugin.entities.plugin import ModelProviderID
|
||||||
|
from events.message_event import message_was_created
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs import datetime_utils
|
||||||
|
from models.model import Message
|
||||||
|
from models.provider import Provider, ProviderType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class _ProviderUpdateFilters(BaseModel):
|
||||||
|
"""Filters for identifying Provider records to update."""
|
||||||
|
|
||||||
|
tenant_id: str
|
||||||
|
provider_name: str
|
||||||
|
provider_type: Optional[str] = None
|
||||||
|
quota_type: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class _ProviderUpdateAdditionalFilters(BaseModel):
|
||||||
|
"""Additional filters for Provider updates."""
|
||||||
|
|
||||||
|
quota_limit_check: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class _ProviderUpdateValues(BaseModel):
|
||||||
|
"""Values to update in Provider records."""
|
||||||
|
|
||||||
|
last_used: Optional[datetime] = None
|
||||||
|
quota_used: Optional[Any] = None # Can be Provider.quota_used + int expression
|
||||||
|
|
||||||
|
|
||||||
|
class _ProviderUpdateOperation(BaseModel):
|
||||||
|
"""A single Provider update operation."""
|
||||||
|
|
||||||
|
filters: _ProviderUpdateFilters
|
||||||
|
values: _ProviderUpdateValues
|
||||||
|
additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()
|
||||||
|
description: str = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
@message_was_created.connect
|
||||||
|
def handle(sender: Message, **kwargs):
|
||||||
|
"""
|
||||||
|
Consolidated handler for Provider updates when a message is created.
|
||||||
|
|
||||||
|
This handler replaces both:
|
||||||
|
- update_provider_last_used_at_when_message_created
|
||||||
|
- deduct_quota_when_message_created
|
||||||
|
|
||||||
|
By performing all Provider updates in a single transaction, we ensure
|
||||||
|
consistency and efficiency when updating Provider records.
|
||||||
|
"""
|
||||||
|
message = sender
|
||||||
|
application_generate_entity = kwargs.get("application_generate_entity")
|
||||||
|
|
||||||
|
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
|
||||||
|
return
|
||||||
|
|
||||||
|
tenant_id = application_generate_entity.app_config.tenant_id
|
||||||
|
provider_name = application_generate_entity.model_conf.provider
|
||||||
|
current_time = datetime_utils.naive_utc_now()
|
||||||
|
|
||||||
|
# Prepare updates for both scenarios
|
||||||
|
updates_to_perform: list[_ProviderUpdateOperation] = []
|
||||||
|
|
||||||
|
# 1. Always update last_used for the provider
|
||||||
|
basic_update = _ProviderUpdateOperation(
|
||||||
|
filters=_ProviderUpdateFilters(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_name=provider_name,
|
||||||
|
),
|
||||||
|
values=_ProviderUpdateValues(last_used=current_time),
|
||||||
|
description="basic_last_used_update",
|
||||||
|
)
|
||||||
|
updates_to_perform.append(basic_update)
|
||||||
|
|
||||||
|
# 2. Check if we need to deduct quota (system provider only)
|
||||||
|
model_config = application_generate_entity.model_conf
|
||||||
|
provider_model_bundle = model_config.provider_model_bundle
|
||||||
|
provider_configuration = provider_model_bundle.configuration
|
||||||
|
|
||||||
|
if (
|
||||||
|
provider_configuration.using_provider_type == ProviderType.SYSTEM
|
||||||
|
and provider_configuration.system_configuration
|
||||||
|
and provider_configuration.system_configuration.current_quota_type is not None
|
||||||
|
):
|
||||||
|
system_configuration = provider_configuration.system_configuration
|
||||||
|
|
||||||
|
# Calculate quota usage
|
||||||
|
used_quota = _calculate_quota_usage(
|
||||||
|
message=message,
|
||||||
|
system_configuration=system_configuration,
|
||||||
|
model_name=model_config.model,
|
||||||
|
)
|
||||||
|
|
||||||
|
if used_quota is not None:
|
||||||
|
quota_update = _ProviderUpdateOperation(
|
||||||
|
filters=_ProviderUpdateFilters(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_name=ModelProviderID(model_config.provider).provider_name,
|
||||||
|
provider_type=ProviderType.SYSTEM.value,
|
||||||
|
quota_type=provider_configuration.system_configuration.current_quota_type.value,
|
||||||
|
),
|
||||||
|
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
|
||||||
|
additional_filters=_ProviderUpdateAdditionalFilters(
|
||||||
|
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
|
||||||
|
),
|
||||||
|
description="quota_deduction_update",
|
||||||
|
)
|
||||||
|
updates_to_perform.append(quota_update)
|
||||||
|
|
||||||
|
# Execute all updates
|
||||||
|
start_time = time_module.perf_counter()
|
||||||
|
try:
|
||||||
|
_execute_provider_updates(updates_to_perform)
|
||||||
|
|
||||||
|
# Log successful completion with timing
|
||||||
|
duration = time_module.perf_counter() - start_time
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Provider updates completed successfully. "
|
||||||
|
f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, "
|
||||||
|
f"Tenant: {tenant_id}, Provider: {provider_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Log failure with timing and context
|
||||||
|
duration = time_module.perf_counter() - start_time
|
||||||
|
|
||||||
|
logger.exception(
|
||||||
|
f"Provider updates failed after {duration:.3f}s. "
|
||||||
|
f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, "
|
||||||
|
f"Provider: {provider_name}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_quota_usage(
|
||||||
|
*, message: Message, system_configuration: SystemConfiguration, model_name: str
|
||||||
|
) -> Optional[int]:
|
||||||
|
"""Calculate quota usage based on message tokens and quota type."""
|
||||||
|
quota_unit = None
|
||||||
|
for quota_configuration in system_configuration.quota_configurations:
|
||||||
|
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||||
|
quota_unit = quota_configuration.quota_unit
|
||||||
|
if quota_configuration.quota_limit == -1:
|
||||||
|
return None
|
||||||
|
break
|
||||||
|
if quota_unit is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if quota_unit == QuotaUnit.TOKENS:
|
||||||
|
tokens = message.message_tokens + message.answer_tokens
|
||||||
|
return tokens
|
||||||
|
if quota_unit == QuotaUnit.CREDITS:
|
||||||
|
tokens = dify_config.get_model_credits(model_name)
|
||||||
|
return tokens
|
||||||
|
elif quota_unit == QuotaUnit.TIMES:
|
||||||
|
return 1
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to calculate quota usage")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
|
||||||
|
"""Execute all Provider updates in a single transaction."""
|
||||||
|
if not updates_to_perform:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Use SQLAlchemy's context manager for transaction management
|
||||||
|
# This automatically handles commit/rollback
|
||||||
|
with db.session.begin():
|
||||||
|
# Use a single transaction for all updates
|
||||||
|
for update_operation in updates_to_perform:
|
||||||
|
filters = update_operation.filters
|
||||||
|
values = update_operation.values
|
||||||
|
additional_filters = update_operation.additional_filters
|
||||||
|
description = update_operation.description
|
||||||
|
|
||||||
|
# Build the where conditions
|
||||||
|
where_conditions = [
|
||||||
|
Provider.tenant_id == filters.tenant_id,
|
||||||
|
Provider.provider_name == filters.provider_name,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add additional filters if specified
|
||||||
|
if filters.provider_type is not None:
|
||||||
|
where_conditions.append(Provider.provider_type == filters.provider_type)
|
||||||
|
if filters.quota_type is not None:
|
||||||
|
where_conditions.append(Provider.quota_type == filters.quota_type)
|
||||||
|
if additional_filters.quota_limit_check:
|
||||||
|
where_conditions.append(Provider.quota_limit > Provider.quota_used)
|
||||||
|
|
||||||
|
# Prepare values dict for SQLAlchemy update
|
||||||
|
update_values = {}
|
||||||
|
if values.last_used is not None:
|
||||||
|
update_values["last_used"] = values.last_used
|
||||||
|
if values.quota_used is not None:
|
||||||
|
update_values["quota_used"] = values.quota_used
|
||||||
|
|
||||||
|
# Build and execute the update statement
|
||||||
|
stmt = update(Provider).where(*where_conditions).values(**update_values)
|
||||||
|
result = db.session.execute(stmt)
|
||||||
|
rows_affected = result.rowcount
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Provider update ({description}): {rows_affected} rows affected. "
|
||||||
|
f"Filters: {filters.model_dump()}, Values: {update_values}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If no rows were affected for quota updates, log a warning
|
||||||
|
if rows_affected == 0 and description == "quota_deduction_update":
|
||||||
|
logger.warning(
|
||||||
|
f"No Provider rows updated for quota deduction. "
|
||||||
|
f"This may indicate quota limit exceeded or provider not found. "
|
||||||
|
f"Filters: {filters.model_dump()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates")
|
@@ -914,11 +914,11 @@ class Message(Base):
|
|||||||
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
|
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
|
||||||
query: Mapped[str] = db.Column(db.Text, nullable=False)
|
query: Mapped[str] = db.Column(db.Text, nullable=False)
|
||||||
message = db.Column(db.JSON, nullable=False)
|
message = db.Column(db.JSON, nullable=False)
|
||||||
message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
message_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||||
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
||||||
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||||
answer: Mapped[str] = db.Column(db.Text, nullable=False)
|
answer: Mapped[str] = db.Column(db.Text, nullable=False)
|
||||||
answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
answer_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||||
answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
||||||
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||||
parent_message_id = db.Column(StringUUID, nullable=True)
|
parent_message_id = db.Column(StringUUID, nullable=True)
|
||||||
|
@@ -155,6 +155,7 @@ dev = [
|
|||||||
"types_setuptools>=80.9.0",
|
"types_setuptools>=80.9.0",
|
||||||
"pandas-stubs~=2.2.3",
|
"pandas-stubs~=2.2.3",
|
||||||
"scipy-stubs>=1.15.3.0",
|
"scipy-stubs>=1.15.3.0",
|
||||||
|
"types-python-http-client>=3.3.7.20240910",
|
||||||
]
|
]
|
||||||
|
|
||||||
############################################################
|
############################################################
|
||||||
|
4269
api/uv.lock
generated
4269
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,248 @@
|
|||||||
|
import threading
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
|
||||||
|
from core.entities.provider_entities import QuotaUnit
|
||||||
|
from events.event_handlers.update_provider_when_message_created import (
|
||||||
|
handle,
|
||||||
|
get_update_stats,
|
||||||
|
)
|
||||||
|
from models.provider import ProviderType
|
||||||
|
from sqlalchemy.exc import OperationalError
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderUpdateDeadlockPrevention:
|
||||||
|
"""Test suite for deadlock prevention in Provider updates."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Setup test fixtures."""
|
||||||
|
self.mock_message = Mock()
|
||||||
|
self.mock_message.answer_tokens = 100
|
||||||
|
|
||||||
|
self.mock_app_config = Mock()
|
||||||
|
self.mock_app_config.tenant_id = "test-tenant-123"
|
||||||
|
|
||||||
|
self.mock_model_conf = Mock()
|
||||||
|
self.mock_model_conf.provider = "openai"
|
||||||
|
|
||||||
|
self.mock_system_config = Mock()
|
||||||
|
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
|
||||||
|
|
||||||
|
self.mock_provider_config = Mock()
|
||||||
|
self.mock_provider_config.using_provider_type = ProviderType.SYSTEM
|
||||||
|
self.mock_provider_config.system_configuration = self.mock_system_config
|
||||||
|
|
||||||
|
self.mock_provider_bundle = Mock()
|
||||||
|
self.mock_provider_bundle.configuration = self.mock_provider_config
|
||||||
|
|
||||||
|
self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle
|
||||||
|
|
||||||
|
self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity)
|
||||||
|
self.mock_generate_entity.app_config = self.mock_app_config
|
||||||
|
self.mock_generate_entity.model_conf = self.mock_model_conf
|
||||||
|
|
||||||
|
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||||
|
def test_consolidated_handler_basic_functionality(self, mock_db):
|
||||||
|
"""Test that the consolidated handler performs both updates correctly."""
|
||||||
|
# Setup mock query chain
|
||||||
|
mock_query = Mock()
|
||||||
|
mock_db.session.query.return_value = mock_query
|
||||||
|
mock_query.filter.return_value = mock_query
|
||||||
|
mock_query.order_by.return_value = mock_query
|
||||||
|
mock_query.update.return_value = 1 # 1 row affected
|
||||||
|
|
||||||
|
# Call the handler
|
||||||
|
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||||
|
|
||||||
|
# Verify db.session.query was called
|
||||||
|
assert mock_db.session.query.called
|
||||||
|
|
||||||
|
# Verify commit was called
|
||||||
|
mock_db.session.commit.assert_called_once()
|
||||||
|
|
||||||
|
# Verify no rollback was called
|
||||||
|
assert not mock_db.session.rollback.called
|
||||||
|
|
||||||
|
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||||
|
def test_deadlock_retry_mechanism(self, mock_db):
|
||||||
|
"""Test that deadlock errors trigger retry logic."""
|
||||||
|
# Setup mock to raise deadlock error on first attempt, succeed on second
|
||||||
|
mock_query = Mock()
|
||||||
|
mock_db.session.query.return_value = mock_query
|
||||||
|
mock_query.filter.return_value = mock_query
|
||||||
|
mock_query.order_by.return_value = mock_query
|
||||||
|
mock_query.update.return_value = 1
|
||||||
|
|
||||||
|
# First call raises deadlock, second succeeds
|
||||||
|
mock_db.session.commit.side_effect = [
|
||||||
|
OperationalError("deadlock detected", None, None),
|
||||||
|
None, # Success on retry
|
||||||
|
]
|
||||||
|
|
||||||
|
# Call the handler
|
||||||
|
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||||
|
|
||||||
|
# Verify commit was called twice (original + retry)
|
||||||
|
assert mock_db.session.commit.call_count == 2
|
||||||
|
|
||||||
|
# Verify rollback was called once (after first failure)
|
||||||
|
mock_db.session.rollback.assert_called_once()
|
||||||
|
|
||||||
|
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||||
|
@patch("events.event_handlers.update_provider_when_message_created.time.sleep")
|
||||||
|
def test_exponential_backoff_timing(self, mock_sleep, mock_db):
|
||||||
|
"""Test that retry delays follow exponential backoff pattern."""
|
||||||
|
# Setup mock to fail twice, succeed on third attempt
|
||||||
|
mock_query = Mock()
|
||||||
|
mock_db.session.query.return_value = mock_query
|
||||||
|
mock_query.filter.return_value = mock_query
|
||||||
|
mock_query.order_by.return_value = mock_query
|
||||||
|
mock_query.update.return_value = 1
|
||||||
|
|
||||||
|
mock_db.session.commit.side_effect = [
|
||||||
|
OperationalError("deadlock detected", None, None),
|
||||||
|
OperationalError("deadlock detected", None, None),
|
||||||
|
None, # Success on third attempt
|
||||||
|
]
|
||||||
|
|
||||||
|
# Call the handler
|
||||||
|
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||||
|
|
||||||
|
# Verify sleep was called twice with increasing delays
|
||||||
|
assert mock_sleep.call_count == 2
|
||||||
|
|
||||||
|
# First delay should be around 0.1s + jitter
|
||||||
|
first_delay = mock_sleep.call_args_list[0][0][0]
|
||||||
|
assert 0.1 <= first_delay <= 0.3
|
||||||
|
|
||||||
|
# Second delay should be around 0.2s + jitter
|
||||||
|
second_delay = mock_sleep.call_args_list[1][0][0]
|
||||||
|
assert 0.2 <= second_delay <= 0.4
|
||||||
|
|
||||||
|
def test_concurrent_handler_execution(self):
|
||||||
|
"""Test that multiple handlers can run concurrently without deadlock."""
|
||||||
|
results = []
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
def run_handler():
|
||||||
|
try:
|
||||||
|
with patch(
|
||||||
|
"events.event_handlers.update_provider_when_message_created.db"
|
||||||
|
) as mock_db:
|
||||||
|
mock_query = Mock()
|
||||||
|
mock_db.session.query.return_value = mock_query
|
||||||
|
mock_query.filter.return_value = mock_query
|
||||||
|
mock_query.order_by.return_value = mock_query
|
||||||
|
mock_query.update.return_value = 1
|
||||||
|
|
||||||
|
handle(
|
||||||
|
self.mock_message,
|
||||||
|
application_generate_entity=self.mock_generate_entity,
|
||||||
|
)
|
||||||
|
results.append("success")
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(str(e))
|
||||||
|
|
||||||
|
# Run multiple handlers concurrently
|
||||||
|
threads = []
|
||||||
|
for _ in range(5):
|
||||||
|
thread = threading.Thread(target=run_handler)
|
||||||
|
threads.append(thread)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
# Wait for all threads to complete
|
||||||
|
for thread in threads:
|
||||||
|
thread.join(timeout=5)
|
||||||
|
|
||||||
|
# Verify all handlers completed successfully
|
||||||
|
assert len(results) == 5
|
||||||
|
assert len(errors) == 0
|
||||||
|
|
||||||
|
def test_performance_stats_tracking(self):
|
||||||
|
"""Test that performance statistics are tracked correctly."""
|
||||||
|
# Reset stats
|
||||||
|
stats = get_update_stats()
|
||||||
|
initial_total = stats["total_updates"]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"events.event_handlers.update_provider_when_message_created.db"
|
||||||
|
) as mock_db:
|
||||||
|
mock_query = Mock()
|
||||||
|
mock_db.session.query.return_value = mock_query
|
||||||
|
mock_query.filter.return_value = mock_query
|
||||||
|
mock_query.order_by.return_value = mock_query
|
||||||
|
mock_query.update.return_value = 1
|
||||||
|
|
||||||
|
# Call handler
|
||||||
|
handle(
|
||||||
|
self.mock_message, application_generate_entity=self.mock_generate_entity
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that stats were updated
|
||||||
|
updated_stats = get_update_stats()
|
||||||
|
assert updated_stats["total_updates"] == initial_total + 1
|
||||||
|
assert updated_stats["successful_updates"] >= initial_total + 1
|
||||||
|
|
||||||
|
def test_non_chat_entity_ignored(self):
|
||||||
|
"""Test that non-chat entities are ignored by the handler."""
|
||||||
|
# Create a non-chat entity
|
||||||
|
mock_non_chat_entity = Mock()
|
||||||
|
mock_non_chat_entity.__class__.__name__ = "NonChatEntity"
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"events.event_handlers.update_provider_when_message_created.db"
|
||||||
|
) as mock_db:
|
||||||
|
# Call handler with non-chat entity
|
||||||
|
handle(self.mock_message, application_generate_entity=mock_non_chat_entity)
|
||||||
|
|
||||||
|
# Verify no database operations were performed
|
||||||
|
assert not mock_db.session.query.called
|
||||||
|
assert not mock_db.session.commit.called
|
||||||
|
|
||||||
|
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||||
|
def test_quota_calculation_tokens(self, mock_db):
|
||||||
|
"""Test quota calculation for token-based quotas."""
|
||||||
|
# Setup token-based quota
|
||||||
|
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
|
||||||
|
self.mock_message.answer_tokens = 150
|
||||||
|
|
||||||
|
mock_query = Mock()
|
||||||
|
mock_db.session.query.return_value = mock_query
|
||||||
|
mock_query.filter.return_value = mock_query
|
||||||
|
mock_query.order_by.return_value = mock_query
|
||||||
|
mock_query.update.return_value = 1
|
||||||
|
|
||||||
|
# Call handler
|
||||||
|
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||||
|
|
||||||
|
# Verify update was called with token count
|
||||||
|
update_calls = mock_query.update.call_args_list
|
||||||
|
|
||||||
|
# Should have at least one call with quota_used update
|
||||||
|
quota_update_found = False
|
||||||
|
for call in update_calls:
|
||||||
|
values = call[0][0] # First argument to update()
|
||||||
|
if "quota_used" in values:
|
||||||
|
quota_update_found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
assert quota_update_found
|
||||||
|
|
||||||
|
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||||
|
def test_quota_calculation_times(self, mock_db):
|
||||||
|
"""Test quota calculation for times-based quotas."""
|
||||||
|
# Setup times-based quota
|
||||||
|
self.mock_system_config.current_quota_type = QuotaUnit.TIMES
|
||||||
|
|
||||||
|
mock_query = Mock()
|
||||||
|
mock_db.session.query.return_value = mock_query
|
||||||
|
mock_query.filter.return_value = mock_query
|
||||||
|
mock_query.order_by.return_value = mock_query
|
||||||
|
mock_query.update.return_value = 1
|
||||||
|
|
||||||
|
# Call handler
|
||||||
|
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||||
|
|
||||||
|
# Verify update was called
|
||||||
|
assert mock_query.update.called
|
||||||
|
assert mock_db.session.commit.called
|
Reference in New Issue
Block a user