From 6900b0813486cd82f0e71541883cf285b12b8e12 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 9 Aug 2025 22:42:18 +0800 Subject: [PATCH] fix: sync missing conversation variables for existing conversations (#23649) --- api/core/app/apps/advanced_chat/app_runner.py | 119 ++++- .../test_app_runner_conversation_variables.py | 419 ++++++++++++++++++ 2 files changed, 518 insertions(+), 20 deletions(-) create mode 100644 api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index a75e17af6..3de2f5ca9 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -118,26 +118,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ): return - # Init conversation variables - stmt = select(ConversationVariable).where( - ConversationVariable.app_id == self.conversation.app_id, - ConversationVariable.conversation_id == self.conversation.id, - ) - with Session(db.engine) as session: - db_conversation_variables = session.scalars(stmt).all() - if not db_conversation_variables: - # Create conversation variables if they don't exist. - db_conversation_variables = [ - ConversationVariable.from_variable( - app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable - ) - for variable in self._workflow.conversation_variables - ] - session.add_all(db_conversation_variables) - # Convert database entities to variables. - conversation_variables = [item.to_variable() for item in db_conversation_variables] - - session.commit() + # Initialize conversation variables + conversation_variables = self._initialize_conversation_variables() # Create a variable pool. system_inputs = SystemVariable( @@ -292,3 +274,100 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): message_id=message_id, trace_manager=app_generate_entity.trace_manager, ) + + def _initialize_conversation_variables(self) -> list[VariableUnion]: + """ + Initialize conversation variables for the current conversation. + + This method: + 1. Loads existing variables from the database + 2. Creates new variables if none exist + 3. Syncs missing variables from the workflow definition + + :return: List of conversation variables ready for use + """ + with Session(db.engine) as session: + existing_variables = self._load_existing_conversation_variables(session) + + if not existing_variables: + # First time initialization - create all variables + existing_variables = self._create_all_conversation_variables(session) + else: + # Check and add any missing variables from the workflow + existing_variables = self._sync_missing_conversation_variables(session, existing_variables) + + # Convert to Variable objects for use in the workflow + conversation_variables = [var.to_variable() for var in existing_variables] + + session.commit() + return cast(list[VariableUnion], conversation_variables) + + def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]: + """ + Load existing conversation variables from the database. + + :param session: Database session + :return: List of existing conversation variables + """ + stmt = select(ConversationVariable).where( + ConversationVariable.app_id == self.conversation.app_id, + ConversationVariable.conversation_id == self.conversation.id, + ) + return list(session.scalars(stmt).all()) + + def _create_all_conversation_variables(self, session: Session) -> list[ConversationVariable]: + """ + Create all conversation variables for a new conversation. + + :param session: Database session + :return: List of created conversation variables + """ + new_variables = [ + ConversationVariable.from_variable( + app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable + ) + for variable in self._workflow.conversation_variables + ] + + if new_variables: + session.add_all(new_variables) + + return new_variables + + def _sync_missing_conversation_variables( + self, session: Session, existing_variables: list[ConversationVariable] + ) -> list[ConversationVariable]: + """ + Sync missing conversation variables from the workflow definition. + + This handles the case where new variables are added to a workflow + after conversations have already been created. + + :param session: Database session + :param existing_variables: List of existing conversation variables + :return: Updated list including any newly created variables + """ + # Get IDs of existing and workflow variables + existing_ids = {var.id for var in existing_variables} + workflow_variables = {var.id: var for var in self._workflow.conversation_variables} + + # Find missing variable IDs + missing_ids = set(workflow_variables.keys()) - existing_ids + + if not missing_ids: + return existing_variables + + # Create missing variables with their default values + new_variables = [ + ConversationVariable.from_variable( + app_id=self.conversation.app_id, + conversation_id=self.conversation.id, + variable=workflow_variables[var_id], + ) + for var_id in missing_ids + ] + + session.add_all(new_variables) + + # Return combined list + return existing_variables + new_variables diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py new file mode 100644 index 000000000..da175e7cc --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -0,0 +1,419 @@ +"""Test conversation variable handling in AdvancedChatAppRunner.""" + +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +from sqlalchemy.orm import Session + +from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.variables import SegmentType +from factories import variable_factory +from models import ConversationVariable, Workflow + + +class TestAdvancedChatAppRunnerConversationVariables: + """Test that AdvancedChatAppRunner correctly handles conversation variables.""" + + def test_missing_conversation_variables_are_added(self): + """Test that new conversation variables added to workflow are created for existing conversations.""" + # Setup + app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_id = str(uuid4()) + + # Create workflow with two conversation variables + workflow_vars = [ + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var1", + "name": "existing_var", + "value_type": SegmentType.STRING, + "value": "default1", + } + ), + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var2", + "name": "new_var", + "value_type": SegmentType.STRING, + "value": "default2", + } + ), + ] + + # Mock workflow with conversation variables + mock_workflow = MagicMock(spec=Workflow) + mock_workflow.conversation_variables = workflow_vars + mock_workflow.tenant_id = str(uuid4()) + mock_workflow.app_id = app_id + mock_workflow.id = workflow_id + mock_workflow.type = "chat" + mock_workflow.graph_dict = {} + mock_workflow.environment_variables = [] + + # Create existing conversation variable (only var1 exists in DB) + existing_db_var = MagicMock(spec=ConversationVariable) + existing_db_var.id = "var1" + existing_db_var.app_id = app_id + existing_db_var.conversation_id = conversation_id + existing_db_var.to_variable = MagicMock(return_value=workflow_vars[0]) + + # Mock conversation and message + mock_conversation = MagicMock() + mock_conversation.app_id = app_id + mock_conversation.id = conversation_id + + mock_message = MagicMock() + mock_message.id = str(uuid4()) + + # Mock app config + mock_app_config = MagicMock() + mock_app_config.app_id = app_id + mock_app_config.workflow_id = workflow_id + mock_app_config.tenant_id = str(uuid4()) + + # Mock app generate entity + mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity) + mock_app_generate_entity.app_config = mock_app_config + mock_app_generate_entity.inputs = {} + mock_app_generate_entity.query = "test query" + mock_app_generate_entity.files = [] + mock_app_generate_entity.user_id = str(uuid4()) + mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API + mock_app_generate_entity.workflow_run_id = str(uuid4()) + mock_app_generate_entity.call_depth = 0 + mock_app_generate_entity.single_iteration_run = None + mock_app_generate_entity.single_loop_run = None + mock_app_generate_entity.trace_manager = None + + # Create runner + runner = AdvancedChatAppRunner( + application_generate_entity=mock_app_generate_entity, + queue_manager=MagicMock(), + conversation=mock_conversation, + message=mock_message, + dialogue_count=1, + variable_loader=MagicMock(), + workflow=mock_workflow, + system_user_id=str(uuid4()), + app=MagicMock(), + ) + + # Mock database session + mock_session = MagicMock(spec=Session) + + # First query returns only existing variable + mock_scalars_result = MagicMock() + mock_scalars_result.all.return_value = [existing_db_var] + mock_session.scalars.return_value = mock_scalars_result + + # Track what gets added to session + added_items = [] + + def track_add_all(items): + added_items.extend(items) + + mock_session.add_all.side_effect = track_add_all + + # Patch the necessary components + with ( + patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, + patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, + patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, + patch.object(runner, "_init_graph") as mock_init_graph, + patch.object(runner, "handle_input_moderation", return_value=False), + patch.object(runner, "handle_annotation_reply", return_value=False), + patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, + patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class, + ): + # Setup mocks + mock_session_class.return_value.__enter__.return_value = mock_session + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists + mock_db.engine = MagicMock() + + # Mock graph initialization + mock_init_graph.return_value = MagicMock() + + # Mock workflow entry + mock_workflow_entry = MagicMock() + mock_workflow_entry.run.return_value = iter([]) # Empty generator + mock_workflow_entry_class.return_value = mock_workflow_entry + + # Run the method + runner.run() + + # Verify that the missing variable was added + assert len(added_items) == 1, "Should have added exactly one missing variable" + + # Check that the added item is the missing variable (var2) + added_var = added_items[0] + assert hasattr(added_var, "id"), "Added item should be a ConversationVariable" + # Note: Since we're mocking ConversationVariable.from_variable, + # we can't directly check the id, but we can verify add_all was called + assert mock_session.add_all.called, "Session add_all should have been called" + assert mock_session.commit.called, "Session commit should have been called" + + def test_no_variables_creates_all(self): + """Test that all conversation variables are created when none exist in DB.""" + # Setup + app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_id = str(uuid4()) + + # Create workflow with conversation variables + workflow_vars = [ + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var1", + "name": "var1", + "value_type": SegmentType.STRING, + "value": "default1", + } + ), + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var2", + "name": "var2", + "value_type": SegmentType.STRING, + "value": "default2", + } + ), + ] + + # Mock workflow + mock_workflow = MagicMock(spec=Workflow) + mock_workflow.conversation_variables = workflow_vars + mock_workflow.tenant_id = str(uuid4()) + mock_workflow.app_id = app_id + mock_workflow.id = workflow_id + mock_workflow.type = "chat" + mock_workflow.graph_dict = {} + mock_workflow.environment_variables = [] + + # Mock conversation and message + mock_conversation = MagicMock() + mock_conversation.app_id = app_id + mock_conversation.id = conversation_id + + mock_message = MagicMock() + mock_message.id = str(uuid4()) + + # Mock app config + mock_app_config = MagicMock() + mock_app_config.app_id = app_id + mock_app_config.workflow_id = workflow_id + mock_app_config.tenant_id = str(uuid4()) + + # Mock app generate entity + mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity) + mock_app_generate_entity.app_config = mock_app_config + mock_app_generate_entity.inputs = {} + mock_app_generate_entity.query = "test query" + mock_app_generate_entity.files = [] + mock_app_generate_entity.user_id = str(uuid4()) + mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API + mock_app_generate_entity.workflow_run_id = str(uuid4()) + mock_app_generate_entity.call_depth = 0 + mock_app_generate_entity.single_iteration_run = None + mock_app_generate_entity.single_loop_run = None + mock_app_generate_entity.trace_manager = None + + # Create runner + runner = AdvancedChatAppRunner( + application_generate_entity=mock_app_generate_entity, + queue_manager=MagicMock(), + conversation=mock_conversation, + message=mock_message, + dialogue_count=1, + variable_loader=MagicMock(), + workflow=mock_workflow, + system_user_id=str(uuid4()), + app=MagicMock(), + ) + + # Mock database session + mock_session = MagicMock(spec=Session) + + # Query returns empty list (no existing variables) + mock_scalars_result = MagicMock() + mock_scalars_result.all.return_value = [] + mock_session.scalars.return_value = mock_scalars_result + + # Track what gets added to session + added_items = [] + + def track_add_all(items): + added_items.extend(items) + + mock_session.add_all.side_effect = track_add_all + + # Patch the necessary components + with ( + patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, + patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, + patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, + patch.object(runner, "_init_graph") as mock_init_graph, + patch.object(runner, "handle_input_moderation", return_value=False), + patch.object(runner, "handle_annotation_reply", return_value=False), + patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, + patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class, + patch("core.app.apps.advanced_chat.app_runner.ConversationVariable") as mock_conv_var_class, + ): + # Setup mocks + mock_session_class.return_value.__enter__.return_value = mock_session + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists + mock_db.engine = MagicMock() + + # Mock ConversationVariable.from_variable to return mock objects + mock_conv_vars = [] + for var in workflow_vars: + mock_cv = MagicMock() + mock_cv.id = var.id + mock_cv.to_variable.return_value = var + mock_conv_vars.append(mock_cv) + + mock_conv_var_class.from_variable.side_effect = mock_conv_vars + + # Mock graph initialization + mock_init_graph.return_value = MagicMock() + + # Mock workflow entry + mock_workflow_entry = MagicMock() + mock_workflow_entry.run.return_value = iter([]) # Empty generator + mock_workflow_entry_class.return_value = mock_workflow_entry + + # Run the method + runner.run() + + # Verify that all variables were created + assert len(added_items) == 2, "Should have added both variables" + assert mock_session.add_all.called, "Session add_all should have been called" + assert mock_session.commit.called, "Session commit should have been called" + + def test_all_variables_exist_no_changes(self): + """Test that no changes are made when all variables already exist in DB.""" + # Setup + app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_id = str(uuid4()) + + # Create workflow with conversation variables + workflow_vars = [ + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var1", + "name": "var1", + "value_type": SegmentType.STRING, + "value": "default1", + } + ), + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var2", + "name": "var2", + "value_type": SegmentType.STRING, + "value": "default2", + } + ), + ] + + # Mock workflow + mock_workflow = MagicMock(spec=Workflow) + mock_workflow.conversation_variables = workflow_vars + mock_workflow.tenant_id = str(uuid4()) + mock_workflow.app_id = app_id + mock_workflow.id = workflow_id + mock_workflow.type = "chat" + mock_workflow.graph_dict = {} + mock_workflow.environment_variables = [] + + # Create existing conversation variables (both exist in DB) + existing_db_vars = [] + for var in workflow_vars: + db_var = MagicMock(spec=ConversationVariable) + db_var.id = var.id + db_var.app_id = app_id + db_var.conversation_id = conversation_id + db_var.to_variable = MagicMock(return_value=var) + existing_db_vars.append(db_var) + + # Mock conversation and message + mock_conversation = MagicMock() + mock_conversation.app_id = app_id + mock_conversation.id = conversation_id + + mock_message = MagicMock() + mock_message.id = str(uuid4()) + + # Mock app config + mock_app_config = MagicMock() + mock_app_config.app_id = app_id + mock_app_config.workflow_id = workflow_id + mock_app_config.tenant_id = str(uuid4()) + + # Mock app generate entity + mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity) + mock_app_generate_entity.app_config = mock_app_config + mock_app_generate_entity.inputs = {} + mock_app_generate_entity.query = "test query" + mock_app_generate_entity.files = [] + mock_app_generate_entity.user_id = str(uuid4()) + mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API + mock_app_generate_entity.workflow_run_id = str(uuid4()) + mock_app_generate_entity.call_depth = 0 + mock_app_generate_entity.single_iteration_run = None + mock_app_generate_entity.single_loop_run = None + mock_app_generate_entity.trace_manager = None + + # Create runner + runner = AdvancedChatAppRunner( + application_generate_entity=mock_app_generate_entity, + queue_manager=MagicMock(), + conversation=mock_conversation, + message=mock_message, + dialogue_count=1, + variable_loader=MagicMock(), + workflow=mock_workflow, + system_user_id=str(uuid4()), + app=MagicMock(), + ) + + # Mock database session + mock_session = MagicMock(spec=Session) + + # Query returns all existing variables + mock_scalars_result = MagicMock() + mock_scalars_result.all.return_value = existing_db_vars + mock_session.scalars.return_value = mock_scalars_result + + # Patch the necessary components + with ( + patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, + patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, + patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, + patch.object(runner, "_init_graph") as mock_init_graph, + patch.object(runner, "handle_input_moderation", return_value=False), + patch.object(runner, "handle_annotation_reply", return_value=False), + patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, + patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class, + ): + # Setup mocks + mock_session_class.return_value.__enter__.return_value = mock_session + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists + mock_db.engine = MagicMock() + + # Mock graph initialization + mock_init_graph.return_value = MagicMock() + + # Mock workflow entry + mock_workflow_entry = MagicMock() + mock_workflow_entry.run.return_value = iter([]) # Empty generator + mock_workflow_entry_class.return_value = mock_workflow_entry + + # Run the method + runner.run() + + # Verify that no variables were added + assert not mock_session.add_all.called, "Session add_all should not have been called" + assert mock_session.commit.called, "Session commit should still be called"