From 94f8e48647877f6fe295b6f4bfca40e291319178 Mon Sep 17 00:00:00 2001 From: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Date: Wed, 25 Jun 2025 11:44:35 +0800 Subject: [PATCH] Refactor update dataset (fix #21401) (#21402) Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/services/dataset_service.py | 460 ++++++---- ...t_service_batch_update_document_status.py} | 0 .../test_dataset_service_update_dataset.py | 826 ++++++++++++++++++ 3 files changed, 1133 insertions(+), 153 deletions(-) rename api/tests/unit_tests/services/{test_dataset_service.py => test_dataset_service_batch_update_document_status.py} (100%) create mode 100644 api/tests/unit_tests/services/test_dataset_service_update_dataset.py diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 49ca98624..af1210c19 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -280,174 +280,328 @@ class DatasetService: @staticmethod def update_dataset(dataset_id, data, user): + """ + Update dataset configuration and settings. + + Args: + dataset_id: The unique identifier of the dataset to update + data: Dictionary containing the update data + user: The user performing the update operation + + Returns: + Dataset: The updated dataset object + + Raises: + ValueError: If dataset not found or validation fails + NoPermissionError: If user lacks permission to update the dataset + """ + # Retrieve and validate dataset existence dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise ValueError("Dataset not found") + # Verify user has permission to update this dataset DatasetService.check_dataset_permission(dataset, user) + + # Handle external dataset updates if dataset.provider == "external": - external_retrieval_model = data.get("external_retrieval_model", None) - if external_retrieval_model: - dataset.retrieval_model = external_retrieval_model - dataset.name = data.get("name", dataset.name) - dataset.description = data.get("description", "") - permission = data.get("permission") - if permission: - dataset.permission = permission - external_knowledge_id = data.get("external_knowledge_id", None) - db.session.add(dataset) - if not external_knowledge_id: - raise ValueError("External knowledge id is required.") - external_knowledge_api_id = data.get("external_knowledge_api_id", None) - if not external_knowledge_api_id: - raise ValueError("External knowledge api id is required.") - - with Session(db.engine) as session: - external_knowledge_binding = ( - session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() - ) - - if not external_knowledge_binding: - raise ValueError("External knowledge binding not found.") - - if ( - external_knowledge_binding.external_knowledge_id != external_knowledge_id - or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id - ): - external_knowledge_binding.external_knowledge_id = external_knowledge_id - external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id - db.session.add(external_knowledge_binding) - db.session.commit() + return DatasetService._update_external_dataset(dataset, data, user) else: - data.pop("partial_member_list", None) - data.pop("external_knowledge_api_id", None) - data.pop("external_knowledge_id", None) - data.pop("external_retrieval_model", None) - filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} - action 
= None - if dataset.indexing_technique != data["indexing_technique"]: - # if update indexing_technique - if data["indexing_technique"] == "economy": - action = "remove" - filtered_data["embedding_model"] = None - filtered_data["embedding_model_provider"] = None - filtered_data["collection_binding_id"] = None - elif data["indexing_technique"] == "high_quality": - action = "add" - # get embedding model setting - try: - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, - provider=data["embedding_model_provider"], - model_type=ModelType.TEXT_EMBEDDING, - model=data["embedding_model"], - ) - filtered_data["embedding_model"] = embedding_model.model - filtered_data["embedding_model_provider"] = embedding_model.provider - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model - ) - filtered_data["collection_binding_id"] = dataset_collection_binding.id - except LLMBadRequestError: - raise ValueError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider." - ) - except ProviderTokenNotInitError as ex: - raise ValueError(ex.description) - else: - # add default plugin id to both setting sets, to make sure the plugin model provider is consistent - # Skip embedding model checks if not provided in the update request - if ( - "embedding_model_provider" not in data - or "embedding_model" not in data - or not data.get("embedding_model_provider") - or not data.get("embedding_model") - ): - # If the dataset already has embedding model settings, use those - if dataset.embedding_model_provider and dataset.embedding_model: - # Keep existing values - filtered_data["embedding_model_provider"] = dataset.embedding_model_provider - filtered_data["embedding_model"] = dataset.embedding_model - # If collection_binding_id exists, keep it too - if dataset.collection_binding_id: - filtered_data["collection_binding_id"] = dataset.collection_binding_id - # Otherwise, don't try to update embedding model settings at all - # Remove these fields from filtered_data if they exist but are None/empty - if "embedding_model_provider" in filtered_data and not filtered_data["embedding_model_provider"]: - del filtered_data["embedding_model_provider"] - if "embedding_model" in filtered_data and not filtered_data["embedding_model"]: - del filtered_data["embedding_model"] - else: - skip_embedding_update = False - try: - # Handle existing model provider - plugin_model_provider = dataset.embedding_model_provider - plugin_model_provider_str = None - if plugin_model_provider: - plugin_model_provider_str = str(ModelProviderID(plugin_model_provider)) + return DatasetService._update_internal_dataset(dataset, data, user) - # Handle new model provider from request - new_plugin_model_provider = data["embedding_model_provider"] - new_plugin_model_provider_str = None - if new_plugin_model_provider: - new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider)) + @staticmethod + def _update_external_dataset(dataset, data, user): + """ + Update external dataset configuration. 
- # Only update embedding model if both values are provided and different from current - if ( - plugin_model_provider_str != new_plugin_model_provider_str - or data["embedding_model"] != dataset.embedding_model - ): - action = "update" - model_manager = ModelManager() - try: - embedding_model = model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, - provider=data["embedding_model_provider"], - model_type=ModelType.TEXT_EMBEDDING, - model=data["embedding_model"], - ) - except ProviderTokenNotInitError: - # If we can't get the embedding model, skip updating it - # and keep the existing settings if available - if dataset.embedding_model_provider and dataset.embedding_model: - filtered_data["embedding_model_provider"] = dataset.embedding_model_provider - filtered_data["embedding_model"] = dataset.embedding_model - if dataset.collection_binding_id: - filtered_data["collection_binding_id"] = dataset.collection_binding_id - # Skip the rest of the embedding model update - skip_embedding_update = True - if not skip_embedding_update: - filtered_data["embedding_model"] = embedding_model.model - filtered_data["embedding_model_provider"] = embedding_model.provider - dataset_collection_binding = ( - DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model - ) - ) - filtered_data["collection_binding_id"] = dataset_collection_binding.id - except LLMBadRequestError: - raise ValueError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider." - ) - except ProviderTokenNotInitError as ex: - raise ValueError(ex.description) + Args: + dataset: The dataset object to update + data: Update data dictionary + user: User performing the update - filtered_data["updated_by"] = user.id - filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + Returns: + Dataset: Updated dataset object + """ + # Update retrieval model if provided + external_retrieval_model = data.get("external_retrieval_model", None) + if external_retrieval_model: + dataset.retrieval_model = external_retrieval_model - # update Retrieval model - filtered_data["retrieval_model"] = data["retrieval_model"] + # Update basic dataset properties + dataset.name = data.get("name", dataset.name) + dataset.description = data.get("description", dataset.description) - db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data) + # Update permission if provided + permission = data.get("permission") + if permission: + dataset.permission = permission + + # Validate and update external knowledge configuration + external_knowledge_id = data.get("external_knowledge_id", None) + external_knowledge_api_id = data.get("external_knowledge_api_id", None) + + if not external_knowledge_id: + raise ValueError("External knowledge id is required.") + if not external_knowledge_api_id: + raise ValueError("External knowledge api id is required.") + # Update metadata fields + dataset.updated_by = user.id if user else None + dataset.updated_at = datetime.datetime.utcnow() + db.session.add(dataset) + + # Update external knowledge binding + DatasetService._update_external_knowledge_binding(dataset.id, external_knowledge_id, external_knowledge_api_id) + + # Commit changes to database + db.session.commit() - db.session.commit() - if action: - deal_dataset_vector_index_task.delay(dataset_id, action) return dataset + @staticmethod + def _update_external_knowledge_binding(dataset_id, external_knowledge_id, 
external_knowledge_api_id): + """ + Update external knowledge binding configuration. + + Args: + dataset_id: Dataset identifier + external_knowledge_id: External knowledge identifier + external_knowledge_api_id: External knowledge API identifier + """ + with Session(db.engine) as session: + external_knowledge_binding = ( + session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() + ) + + if not external_knowledge_binding: + raise ValueError("External knowledge binding not found.") + + # Update binding if values have changed + if ( + external_knowledge_binding.external_knowledge_id != external_knowledge_id + or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id + ): + external_knowledge_binding.external_knowledge_id = external_knowledge_id + external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id + db.session.add(external_knowledge_binding) + + @staticmethod + def _update_internal_dataset(dataset, data, user): + """ + Update internal dataset configuration. + + Args: + dataset: The dataset object to update + data: Update data dictionary + user: User performing the update + + Returns: + Dataset: Updated dataset object + """ + # Remove external-specific fields from update data + data.pop("partial_member_list", None) + data.pop("external_knowledge_api_id", None) + data.pop("external_knowledge_id", None) + data.pop("external_retrieval_model", None) + + # Filter out None values except for description field + filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} + + # Handle indexing technique changes and embedding model updates + action = DatasetService._handle_indexing_technique_change(dataset, data, filtered_data) + + # Add metadata fields + filtered_data["updated_by"] = user.id + filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + # update Retrieval model + filtered_data["retrieval_model"] = data["retrieval_model"] + + # Update dataset in database + db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data) + db.session.commit() + + # Trigger vector index task if indexing technique changed + if action: + deal_dataset_vector_index_task.delay(dataset.id, action) + + return dataset + + @staticmethod + def _handle_indexing_technique_change(dataset, data, filtered_data): + """ + Handle changes in indexing technique and configure embedding models accordingly. + + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data + + Returns: + str: Action to perform ('add', 'remove', 'update', or None) + """ + if dataset.indexing_technique != data["indexing_technique"]: + if data["indexing_technique"] == "economy": + # Remove embedding model configuration for economy mode + filtered_data["embedding_model"] = None + filtered_data["embedding_model_provider"] = None + filtered_data["collection_binding_id"] = None + return "remove" + elif data["indexing_technique"] == "high_quality": + # Configure embedding model for high quality mode + DatasetService._configure_embedding_model_for_high_quality(data, filtered_data) + return "add" + else: + # Handle embedding model updates when indexing technique remains the same + return DatasetService._handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data) + return None + + @staticmethod + def _configure_embedding_model_for_high_quality(data, filtered_data): + """ + Configure embedding model settings for high quality indexing. 
+ + Args: + data: Update data dictionary + filtered_data: Filtered update data to modify + """ + try: + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=data["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=data["embedding_model"], + ) + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + filtered_data["collection_binding_id"] = dataset_collection_binding.id + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + + @staticmethod + def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data): + """ + Handle embedding model updates when indexing technique remains the same. + + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data to modify + + Returns: + str: Action to perform ('update' or None) + """ + # Skip embedding model checks if not provided in the update request + if ( + "embedding_model_provider" not in data + or "embedding_model" not in data + or not data.get("embedding_model_provider") + or not data.get("embedding_model") + ): + DatasetService._preserve_existing_embedding_settings(dataset, filtered_data) + return None + else: + return DatasetService._update_embedding_model_settings(dataset, data, filtered_data) + + @staticmethod + def _preserve_existing_embedding_settings(dataset, filtered_data): + """ + Preserve existing embedding model settings when not provided in update. + + Args: + dataset: Current dataset object + filtered_data: Filtered update data to modify + """ + # If the dataset already has embedding model settings, use those + if dataset.embedding_model_provider and dataset.embedding_model: + filtered_data["embedding_model_provider"] = dataset.embedding_model_provider + filtered_data["embedding_model"] = dataset.embedding_model + # If collection_binding_id exists, keep it too + if dataset.collection_binding_id: + filtered_data["collection_binding_id"] = dataset.collection_binding_id + # Otherwise, don't try to update embedding model settings at all + # Remove these fields from filtered_data if they exist but are None/empty + if "embedding_model_provider" in filtered_data and not filtered_data["embedding_model_provider"]: + del filtered_data["embedding_model_provider"] + if "embedding_model" in filtered_data and not filtered_data["embedding_model"]: + del filtered_data["embedding_model"] + + @staticmethod + def _update_embedding_model_settings(dataset, data, filtered_data): + """ + Update embedding model settings with new values. 
+ + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data to modify + + Returns: + str: Action to perform ('update' or None) + """ + try: + # Compare current and new model provider settings + current_provider_str = ( + str(ModelProviderID(dataset.embedding_model_provider)) if dataset.embedding_model_provider else None + ) + new_provider_str = ( + str(ModelProviderID(data["embedding_model_provider"])) if data["embedding_model_provider"] else None + ) + + # Only update if values are different + if current_provider_str != new_provider_str or data["embedding_model"] != dataset.embedding_model: + DatasetService._apply_new_embedding_settings(dataset, data, filtered_data) + return "update" + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + return None + + @staticmethod + def _apply_new_embedding_settings(dataset, data, filtered_data): + """ + Apply new embedding model settings to the dataset. + + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data to modify + """ + model_manager = ModelManager() + try: + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=data["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=data["embedding_model"], + ) + except ProviderTokenNotInitError: + # If we can't get the embedding model, preserve existing settings + if dataset.embedding_model_provider and dataset.embedding_model: + filtered_data["embedding_model_provider"] = dataset.embedding_model_provider + filtered_data["embedding_model"] = dataset.embedding_model + if dataset.collection_binding_id: + filtered_data["collection_binding_id"] = dataset.collection_binding_id + # Skip the rest of the embedding model update + return + + # Apply new embedding model settings + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + filtered_data["collection_binding_id"] = dataset_collection_binding.id + @staticmethod def delete_dataset(dataset_id, user): dataset = DatasetService.get_dataset(dataset_id) diff --git a/api/tests/unit_tests/services/test_dataset_service.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py similarity index 100% rename from api/tests/unit_tests/services/test_dataset_service.py rename to api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py diff --git a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py new file mode 100644 index 000000000..15e1b7569 --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py @@ -0,0 +1,826 @@ +import datetime + +# Mock redis_client before importing dataset_service +from unittest.mock import Mock, patch + +import pytest + +from core.model_runtime.entities.model_entities import ModelType +from models.dataset import Dataset, ExternalKnowledgeBindings +from services.dataset_service import DatasetService +from services.errors.account import NoPermissionError +from tests.unit_tests.conftest import 
redis_mock + + +class TestDatasetServiceUpdateDataset: + """ + Comprehensive unit tests for DatasetService.update_dataset method. + + This test suite covers all supported scenarios including: + - External dataset updates + - Internal dataset updates with different indexing techniques + - Embedding model updates + - Permission checks + - Error conditions and edge cases + """ + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_external_dataset_success(self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db): + """ + Test successful update of external dataset. + + Verifies that: + 1. External dataset attributes are updated correctly + 2. External knowledge binding is updated when values change + 3. Database changes are committed + 4. Permission check is performed + """ + from unittest.mock import Mock, patch + + from extensions.ext_database import db + + with patch.object(db.__class__, "engine", new_callable=Mock): + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "external" + mock_dataset.name = "old_name" + mock_dataset.description = "old_description" + mock_dataset.retrieval_model = "old_model" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock external knowledge binding + mock_binding = Mock(spec=ExternalKnowledgeBindings) + mock_binding.external_knowledge_id = "old_knowledge_id" + mock_binding.external_knowledge_api_id = "old_api_id" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock external knowledge binding query + with patch("services.dataset_service.Session") as mock_session: + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.query.return_value.filter_by.return_value.first.return_value = mock_binding + + # Test data + update_data = { + "name": "new_name", + "description": "new_description", + "external_retrieval_model": "new_model", + "permission": "only_me", + "external_knowledge_id": "new_knowledge_id", + "external_knowledge_api_id": "new_api_id", + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify permission check was called + mock_check_permission.assert_called_once_with(mock_dataset, mock_user) + + # Verify dataset attributes were updated + assert mock_dataset.name == "new_name" + assert mock_dataset.description == "new_description" + assert mock_dataset.retrieval_model == "new_model" + + # Verify external knowledge binding was updated + assert mock_binding.external_knowledge_id == "new_knowledge_id" + assert mock_binding.external_knowledge_api_id == "new_api_id" + + # Verify database operations + mock_db.add.assert_any_call(mock_dataset) + mock_db.add.assert_any_call(mock_binding) + mock_db.commit.assert_called_once() + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def 
test_update_external_dataset_missing_knowledge_id_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when external knowledge id is missing. + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.provider = "external" + + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data without external_knowledge_id + update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"} + + # Call the method and expect ValueError + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "External knowledge id is required" in str(context.value) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_external_dataset_missing_api_id_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when external knowledge api id is missing. + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.provider = "external" + + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data without external_knowledge_api_id + update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"} + + # Call the method and expect ValueError + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "External knowledge api id is required" in str(context.value) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.Session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_external_dataset_binding_not_found_error( + self, mock_check_permission, mock_get_dataset, mock_session, mock_db + ): + from unittest.mock import Mock, patch + + from extensions.ext_database import db + + with patch.object(db.__class__, "engine", new_callable=Mock): + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.provider = "external" + + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock external knowledge binding query returning None + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.query.return_value.filter_by.return_value.first.return_value = None + + # Test data + update_data = { + "name": "new_name", + "external_knowledge_id": "knowledge_id", + "external_knowledge_api_id": "api_id", + } + + # Call the method and expect ValueError + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "External knowledge binding not found" in str(context.value) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_basic_success( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db + ): + """ + Test successful update of internal dataset with basic fields. + + Verifies that: + 1. 
Basic dataset attributes are updated correctly + 2. Filtered data excludes None values except description + 3. Timestamp fields are updated + 4. Database changes are committed + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.name = "old_name" + mock_dataset.description = "old_description" + mock_dataset.indexing_technique = "high_quality" + mock_dataset.retrieval_model = "old_model" + mock_dataset.embedding_model_provider = "openai" + mock_dataset.embedding_model = "text-embedding-ada-002" + mock_dataset.collection_binding_id = "binding-123" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data + update_data = { + "name": "new_name", + "description": "new_description", + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify permission check was called + mock_check_permission.assert_called_once_with(mock_dataset, mock_user) + + # Verify database update was called with correct filtered data + expected_filtered_data = { + "name": "new_name", + "description": "new_description", + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.deal_dataset_vector_index_task") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_indexing_technique_to_economy( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_task, mock_db + ): + """ + Test updating internal dataset indexing technique to economy. + + Verifies that: + 1. Embedding model fields are cleared when switching to economy + 2. 
Vector index task is triggered with 'remove' action + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data + update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"} + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify database update was called with embedding model fields cleared + expected_filtered_data = { + "indexing_technique": "economy", + "embedding_model": None, + "embedding_model_provider": None, + "collection_binding_id": None, + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify vector index task was triggered + mock_task.delay.assert_called_once_with("dataset-123", "remove") + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_dataset_not_found_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when dataset is not found. + """ + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval returning None + mock_get_dataset.return_value = None + + # Test data + update_data = {"name": "new_name"} + + # Call the method and expect ValueError + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "Dataset not found" in str(context.value) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_dataset_permission_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when user doesn't have permission. + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock permission check to raise error + mock_check_permission.side_effect = NoPermissionError("No permission") + + # Test data + update_data = {"name": "new_name"} + + # Call the method and expect NoPermissionError + with pytest.raises(NoPermissionError): + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_keep_existing_embedding_model( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db + ): + """ + Test updating internal dataset without changing embedding model. + + Verifies that: + 1. 
Existing embedding model settings are preserved when not provided in update + 2. No vector index task is triggered + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + mock_dataset.embedding_model_provider = "openai" + mock_dataset.embedding_model = "text-embedding-ada-002" + mock_dataset.collection_binding_id = "binding-123" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data without embedding model fields + update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"} + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify database update was called with existing embedding model preserved + expected_filtered_data = { + "name": "new_name", + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "collection_binding_id": "binding-123", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding") + @patch("services.dataset_service.ModelManager") + @patch("services.dataset_service.deal_dataset_vector_index_task") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_indexing_technique_to_high_quality( + self, + mock_datetime, + mock_check_permission, + mock_get_dataset, + mock_task, + mock_model_manager, + mock_collection_binding, + mock_db, + ): + """ + Test updating internal dataset indexing technique to high_quality. + + Verifies that: + 1. Embedding model is validated and set + 2. Collection binding is retrieved + 3. 
Vector index task is triggered with 'add' action + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "economy" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock embedding model + mock_embedding_model = Mock() + mock_embedding_model.model = "text-embedding-ada-002" + mock_embedding_model.provider = "openai" + + # Mock collection binding + mock_collection_binding_instance = Mock() + mock_collection_binding_instance.id = "binding-456" + + # Mock model manager + mock_model_manager_instance = Mock() + mock_model_manager_instance.get_model_instance.return_value = mock_embedding_model + mock_model_manager.return_value = mock_model_manager_instance + + # Mock collection binding service + mock_collection_binding.return_value = mock_collection_binding_instance + + # Mock current_user + mock_current_user = Mock() + mock_current_user.current_tenant_id = "tenant-123" + + # Test data + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "retrieval_model": "new_model", + } + + # Call the method with current_user mock + with patch("services.dataset_service.current_user", mock_current_user): + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify embedding model was validated + mock_model_manager_instance.get_model_instance.assert_called_once_with( + tenant_id=mock_current_user.current_tenant_id, + provider="openai", + model_type=ModelType.TEXT_EMBEDDING, + model="text-embedding-ada-002", + ) + + # Verify collection binding was retrieved + mock_collection_binding.assert_called_once_with("openai", "text-embedding-ada-002") + + # Verify database update was called with correct data + expected_filtered_data = { + "indexing_technique": "high_quality", + "embedding_model": "text-embedding-ada-002", + "embedding_model_provider": "openai", + "collection_binding_id": "binding-456", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify vector index task was triggered + mock_task.delay.assert_called_once_with("dataset-123", "add") + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_internal_dataset_embedding_model_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when embedding model is not available. 
+ """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "economy" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock current_user + mock_current_user = Mock() + mock_current_user.current_tenant_id = "tenant-123" + + # Mock model manager to raise error + with ( + patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.current_user", mock_current_user), + ): + mock_model_manager_instance = Mock() + mock_model_manager_instance.get_model_instance.side_effect = Exception("No Embedding Model available") + mock_model_manager.return_value = mock_model_manager_instance + + # Test data + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "invalid_provider", + "embedding_model": "invalid_model", + "retrieval_model": "new_model", + } + + # Call the method and expect ValueError + with pytest.raises(Exception) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "No Embedding Model available".lower() in str(context.value).lower() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_filter_none_values( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db + ): + """ + Test that None values are filtered out except for description field. + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data with None values + update_data = { + "name": "new_name", + "description": None, # Should be included + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "embedding_model_provider": None, # Should be filtered out + "embedding_model": None, # Should be filtered out + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify database update was called with filtered data + expected_filtered_data = { + "name": "new_name", + "description": None, # Description should be included even if None + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": mock_db.query.return_value.filter_by.return_value.update.call_args[0][0]["updated_at"], + } + + actual_call_args = mock_db.query.return_value.filter_by.return_value.update.call_args[0][0] + # Remove timestamp for comparison as it's dynamic + del actual_call_args["updated_at"] + del expected_filtered_data["updated_at"] + + del actual_call_args["collection_binding_id"] + del actual_call_args["embedding_model"] + del actual_call_args["embedding_model_provider"] + + assert actual_call_args == expected_filtered_data + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.deal_dataset_vector_index_task") + @patch("services.dataset_service.DatasetService.get_dataset") + 
@patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_embedding_model_update( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_task, mock_db + ): + """ + Test updating internal dataset with new embedding model. + + Verifies that: + 1. Embedding model is updated when different from current + 2. Vector index task is triggered with 'update' action + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + mock_dataset.embedding_model_provider = "openai" + mock_dataset.embedding_model = "text-embedding-ada-002" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock embedding model + mock_embedding_model = Mock() + mock_embedding_model.model = "text-embedding-3-small" + mock_embedding_model.provider = "openai" + + # Mock collection binding + mock_collection_binding_instance = Mock() + mock_collection_binding_instance.id = "binding-789" + + # Mock current_user + mock_current_user = Mock() + mock_current_user.current_tenant_id = "tenant-123" + + # Mock model manager + with patch("services.dataset_service.ModelManager") as mock_model_manager: + mock_model_manager_instance = Mock() + mock_model_manager_instance.get_model_instance.return_value = mock_embedding_model + mock_model_manager.return_value = mock_model_manager_instance + + # Mock collection binding service + with ( + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" + ) as mock_collection_binding, + patch("services.dataset_service.current_user", mock_current_user), + ): + mock_collection_binding.return_value = mock_collection_binding_instance + + # Test data + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-3-small", + "retrieval_model": "new_model", + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify embedding model was validated + mock_model_manager_instance.get_model_instance.assert_called_once_with( + tenant_id=mock_current_user.current_tenant_id, + provider="openai", + model_type=ModelType.TEXT_EMBEDDING, + model="text-embedding-3-small", + ) + + # Verify collection binding was retrieved + mock_collection_binding.assert_called_once_with("openai", "text-embedding-3-small") + + # Verify database update was called with correct data + expected_filtered_data = { + "indexing_technique": "high_quality", + "embedding_model": "text-embedding-3-small", + "embedding_model_provider": "openai", + "collection_binding_id": "binding-789", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify vector index task was triggered + mock_task.delay.assert_called_once_with("dataset-123", "update") + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + 
@patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_no_indexing_technique_change( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db + ): + """ + Test updating internal dataset without changing indexing technique. + + Verifies that: + 1. No vector index task is triggered when indexing technique doesn't change + 2. Database update is performed normally + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + mock_dataset.embedding_model_provider = "openai" + mock_dataset.embedding_model = "text-embedding-ada-002" + mock_dataset.collection_binding_id = "binding-123" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data with same indexing technique + update_data = { + "name": "new_name", + "indexing_technique": "high_quality", # Same as current + "retrieval_model": "new_model", + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify database update was called with correct data + expected_filtered_data = { + "name": "new_name", + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "collection_binding_id": "binding-123", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify no vector index task was triggered + mock_db.query.return_value.filter_by.return_value.update.assert_called_once() + + # Verify return value + assert result == mock_dataset