Refactor update dataset (fix #21401) (#21402)

Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
NeatGuyCoding
2025-06-25 11:44:35 +08:00
committed by GitHub
parent 819c02f1f5
commit 94f8e48647
3 changed files with 1133 additions and 153 deletions

View File

@@ -280,28 +280,93 @@ class DatasetService:
@staticmethod
def update_dataset(dataset_id, data, user):
"""
Update dataset configuration and settings.
Args:
dataset_id: The unique identifier of the dataset to update
data: Dictionary containing the update data
user: The user performing the update operation
Returns:
Dataset: The updated dataset object
Raises:
ValueError: If dataset not found or validation fails
NoPermissionError: If user lacks permission to update the dataset
"""
# Retrieve and validate dataset existence
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise ValueError("Dataset not found")
# Verify user has permission to update this dataset
DatasetService.check_dataset_permission(dataset, user)
# Handle external dataset updates
if dataset.provider == "external":
return DatasetService._update_external_dataset(dataset, data, user)
else:
return DatasetService._update_internal_dataset(dataset, data, user)
@staticmethod
def _update_external_dataset(dataset, data, user):
"""
Update external dataset configuration.
Args:
dataset: The dataset object to update
data: Update data dictionary
user: User performing the update
Returns:
Dataset: Updated dataset object
"""
# Update retrieval model if provided
external_retrieval_model = data.get("external_retrieval_model", None)
if external_retrieval_model:
dataset.retrieval_model = external_retrieval_model
# Update basic dataset properties
dataset.name = data.get("name", dataset.name)
dataset.description = data.get("description", "")
dataset.description = data.get("description", dataset.description)
# Update permission if provided
permission = data.get("permission")
if permission:
dataset.permission = permission
# Validate and update external knowledge configuration
external_knowledge_id = data.get("external_knowledge_id", None)
db.session.add(dataset)
external_knowledge_api_id = data.get("external_knowledge_api_id", None)
if not external_knowledge_id:
raise ValueError("External knowledge id is required.")
external_knowledge_api_id = data.get("external_knowledge_api_id", None)
if not external_knowledge_api_id:
raise ValueError("External knowledge api id is required.")
# Update metadata fields
dataset.updated_by = user.id if user else None
dataset.updated_at = datetime.datetime.utcnow()
db.session.add(dataset)
# Update external knowledge binding
DatasetService._update_external_knowledge_binding(dataset.id, external_knowledge_id, external_knowledge_api_id)
# Commit changes to database
db.session.commit()
return dataset
@staticmethod
def _update_external_knowledge_binding(dataset_id, external_knowledge_id, external_knowledge_api_id):
"""
Update external knowledge binding configuration.
Args:
dataset_id: Dataset identifier
external_knowledge_id: External knowledge identifier
external_knowledge_api_id: External knowledge API identifier
"""
with Session(db.engine) as session:
external_knowledge_binding = (
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
@@ -310,6 +375,7 @@ class DatasetService:
if not external_knowledge_binding:
raise ValueError("External knowledge binding not found.")
# Update binding if values have changed
if (
external_knowledge_binding.external_knowledge_id != external_knowledge_id
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
@@ -317,24 +383,86 @@ class DatasetService:
external_knowledge_binding.external_knowledge_id = external_knowledge_id
external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
db.session.add(external_knowledge_binding)
db.session.commit()
else:
@staticmethod
def _update_internal_dataset(dataset, data, user):
"""
Update internal dataset configuration.
Args:
dataset: The dataset object to update
data: Update data dictionary
user: User performing the update
Returns:
Dataset: Updated dataset object
"""
# Remove external-specific fields from update data
data.pop("partial_member_list", None)
data.pop("external_knowledge_api_id", None)
data.pop("external_knowledge_id", None)
data.pop("external_retrieval_model", None)
# Filter out None values except for description field
filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"}
action = None
# Handle indexing technique changes and embedding model updates
action = DatasetService._handle_indexing_technique_change(dataset, data, filtered_data)
# Add metadata fields
filtered_data["updated_by"] = user.id
filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
# update Retrieval model
filtered_data["retrieval_model"] = data["retrieval_model"]
# Update dataset in database
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
db.session.commit()
# Trigger vector index task if indexing technique changed
if action:
deal_dataset_vector_index_task.delay(dataset.id, action)
return dataset
@staticmethod
def _handle_indexing_technique_change(dataset, data, filtered_data):
"""
Handle changes in indexing technique and configure embedding models accordingly.
Args:
dataset: Current dataset object
data: Update data dictionary
filtered_data: Filtered update data
Returns:
str: Action to perform ('add', 'remove', 'update', or None)
"""
if dataset.indexing_technique != data["indexing_technique"]:
# if update indexing_technique
if data["indexing_technique"] == "economy":
action = "remove"
# Remove embedding model configuration for economy mode
filtered_data["embedding_model"] = None
filtered_data["embedding_model_provider"] = None
filtered_data["collection_binding_id"] = None
return "remove"
elif data["indexing_technique"] == "high_quality":
action = "add"
# get embedding model setting
# Configure embedding model for high quality mode
DatasetService._configure_embedding_model_for_high_quality(data, filtered_data)
return "add"
else:
# Handle embedding model updates when indexing technique remains the same
return DatasetService._handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data)
return None
@staticmethod
def _configure_embedding_model_for_high_quality(data, filtered_data):
"""
Configure embedding model settings for high quality indexing.
Args:
data: Update data dictionary
filtered_data: Filtered update data to modify
"""
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
@@ -351,13 +479,24 @@ class DatasetService:
filtered_data["collection_binding_id"] = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
else:
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
@staticmethod
def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data):
"""
Handle embedding model updates when indexing technique remains the same.
Args:
dataset: Current dataset object
data: Update data dictionary
filtered_data: Filtered update data to modify
Returns:
str: Action to perform ('update' or None)
"""
# Skip embedding model checks if not provided in the update request
if (
"embedding_model_provider" not in data
@@ -365,9 +504,22 @@ class DatasetService:
or not data.get("embedding_model_provider")
or not data.get("embedding_model")
):
DatasetService._preserve_existing_embedding_settings(dataset, filtered_data)
return None
else:
return DatasetService._update_embedding_model_settings(dataset, data, filtered_data)
@staticmethod
def _preserve_existing_embedding_settings(dataset, filtered_data):
"""
Preserve existing embedding model settings when not provided in update.
Args:
dataset: Current dataset object
filtered_data: Filtered update data to modify
"""
# If the dataset already has embedding model settings, use those
if dataset.embedding_model_provider and dataset.embedding_model:
# Keep existing values
filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
filtered_data["embedding_model"] = dataset.embedding_model
# If collection_binding_id exists, keep it too
@@ -379,27 +531,51 @@ class DatasetService:
del filtered_data["embedding_model_provider"]
if "embedding_model" in filtered_data and not filtered_data["embedding_model"]:
del filtered_data["embedding_model"]
else:
skip_embedding_update = False
@staticmethod
def _update_embedding_model_settings(dataset, data, filtered_data):
"""
Update embedding model settings with new values.
Args:
dataset: Current dataset object
data: Update data dictionary
filtered_data: Filtered update data to modify
Returns:
str: Action to perform ('update' or None)
"""
try:
# Handle existing model provider
plugin_model_provider = dataset.embedding_model_provider
plugin_model_provider_str = None
if plugin_model_provider:
plugin_model_provider_str = str(ModelProviderID(plugin_model_provider))
# Compare current and new model provider settings
current_provider_str = (
str(ModelProviderID(dataset.embedding_model_provider)) if dataset.embedding_model_provider else None
)
new_provider_str = (
str(ModelProviderID(data["embedding_model_provider"])) if data["embedding_model_provider"] else None
)
# Handle new model provider from request
new_plugin_model_provider = data["embedding_model_provider"]
new_plugin_model_provider_str = None
if new_plugin_model_provider:
new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider))
# Only update if values are different
if current_provider_str != new_provider_str or data["embedding_model"] != dataset.embedding_model:
DatasetService._apply_new_embedding_settings(dataset, data, filtered_data)
return "update"
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
return None
# Only update embedding model if both values are provided and different from current
if (
plugin_model_provider_str != new_plugin_model_provider_str
or data["embedding_model"] != dataset.embedding_model
):
action = "update"
@staticmethod
def _apply_new_embedding_settings(dataset, data, filtered_data):
"""
Apply new embedding model settings to the dataset.
Args:
dataset: Current dataset object
data: Update data dictionary
filtered_data: Filtered update data to modify
"""
model_manager = ModelManager()
try:
embedding_model = model_manager.get_model_instance(
@@ -409,44 +585,22 @@ class DatasetService:
model=data["embedding_model"],
)
except ProviderTokenNotInitError:
# If we can't get the embedding model, skip updating it
# and keep the existing settings if available
# If we can't get the embedding model, preserve existing settings
if dataset.embedding_model_provider and dataset.embedding_model:
filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
filtered_data["embedding_model"] = dataset.embedding_model
if dataset.collection_binding_id:
filtered_data["collection_binding_id"] = dataset.collection_binding_id
# Skip the rest of the embedding model update
skip_embedding_update = True
if not skip_embedding_update:
return
# Apply new embedding model settings
filtered_data["embedding_model"] = embedding_model.model
filtered_data["embedding_model_provider"] = embedding_model.provider
dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding(
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
)
filtered_data["collection_binding_id"] = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
filtered_data["updated_by"] = user.id
filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
# update Retrieval model
filtered_data["retrieval_model"] = data["retrieval_model"]
db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data)
db.session.commit()
if action:
deal_dataset_vector_index_task.delay(dataset_id, action)
return dataset
@staticmethod
def delete_dataset(dataset_id, user):

View File

@@ -0,0 +1,826 @@
import datetime
# Mock redis_client before importing dataset_service
from unittest.mock import Mock, patch
import pytest
from core.model_runtime.entities.model_entities import ModelType
from models.dataset import Dataset, ExternalKnowledgeBindings
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
from tests.unit_tests.conftest import redis_mock
class TestDatasetServiceUpdateDataset:
"""
Comprehensive unit tests for DatasetService.update_dataset method.
This test suite covers all supported scenarios including:
- External dataset updates
- Internal dataset updates with different indexing techniques
- Embedding model updates
- Permission checks
- Error conditions and edge cases
"""
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
@patch("services.dataset_service.datetime")
def test_update_external_dataset_success(self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db):
"""
Test successful update of external dataset.
Verifies that:
1. External dataset attributes are updated correctly
2. External knowledge binding is updated when values change
3. Database changes are committed
4. Permission check is performed
"""
from unittest.mock import Mock, patch
from extensions.ext_database import db
with patch.object(db.__class__, "engine", new_callable=Mock):
# Create mock dataset
mock_dataset = Mock(spec=Dataset)
mock_dataset.id = "dataset-123"
mock_dataset.provider = "external"
mock_dataset.name = "old_name"
mock_dataset.description = "old_description"
mock_dataset.retrieval_model = "old_model"
# Create mock user
mock_user = Mock()
mock_user.id = "user-789"
# Create mock external knowledge binding
mock_binding = Mock(spec=ExternalKnowledgeBindings)
mock_binding.external_knowledge_id = "old_knowledge_id"
mock_binding.external_knowledge_api_id = "old_api_id"
# Set up mock return values
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
mock_datetime.datetime.now.return_value = current_time
mock_datetime.UTC = datetime.UTC
# Mock dataset retrieval
mock_get_dataset.return_value = mock_dataset
# Mock external knowledge binding query
with patch("services.dataset_service.Session") as mock_session:
mock_session_instance = Mock()
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_session_instance.query.return_value.filter_by.return_value.first.return_value = mock_binding
# Test data
update_data = {
"name": "new_name",
"description": "new_description",
"external_retrieval_model": "new_model",
"permission": "only_me",
"external_knowledge_id": "new_knowledge_id",
"external_knowledge_api_id": "new_api_id",
}
# Call the method
result = DatasetService.update_dataset("dataset-123", update_data, mock_user)
# Verify permission check was called
mock_check_permission.assert_called_once_with(mock_dataset, mock_user)
# Verify dataset attributes were updated
assert mock_dataset.name == "new_name"
assert mock_dataset.description == "new_description"
assert mock_dataset.retrieval_model == "new_model"
# Verify external knowledge binding was updated
assert mock_binding.external_knowledge_id == "new_knowledge_id"
assert mock_binding.external_knowledge_api_id == "new_api_id"
# Verify database operations
mock_db.add.assert_any_call(mock_dataset)
mock_db.add.assert_any_call(mock_binding)
mock_db.commit.assert_called_once()
# Verify return value
assert result == mock_dataset
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
def test_update_external_dataset_missing_knowledge_id_error(self, mock_check_permission, mock_get_dataset, mock_db):
"""
Test error when external knowledge id is missing.
"""
# Create mock dataset
mock_dataset = Mock(spec=Dataset)
mock_dataset.provider = "external"
# Create mock user
mock_user = Mock()
# Mock dataset retrieval
mock_get_dataset.return_value = mock_dataset
# Test data without external_knowledge_id
update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"}
# Call the method and expect ValueError
with pytest.raises(ValueError) as context:
DatasetService.update_dataset("dataset-123", update_data, mock_user)
assert "External knowledge id is required" in str(context.value)
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
def test_update_external_dataset_missing_api_id_error(self, mock_check_permission, mock_get_dataset, mock_db):
"""
Test error when external knowledge api id is missing.
"""
# Create mock dataset
mock_dataset = Mock(spec=Dataset)
mock_dataset.provider = "external"
# Create mock user
mock_user = Mock()
# Mock dataset retrieval
mock_get_dataset.return_value = mock_dataset
# Test data without external_knowledge_api_id
update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"}
# Call the method and expect ValueError
with pytest.raises(ValueError) as context:
DatasetService.update_dataset("dataset-123", update_data, mock_user)
assert "External knowledge api id is required" in str(context.value)
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.Session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
def test_update_external_dataset_binding_not_found_error(
self, mock_check_permission, mock_get_dataset, mock_session, mock_db
):
from unittest.mock import Mock, patch
from extensions.ext_database import db
with patch.object(db.__class__, "engine", new_callable=Mock):
# Create mock dataset
mock_dataset = Mock(spec=Dataset)
mock_dataset.provider = "external"
# Create mock user
mock_user = Mock()
# Mock dataset retrieval
mock_get_dataset.return_value = mock_dataset
# Mock external knowledge binding query returning None
mock_session_instance = Mock()
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_session_instance.query.return_value.filter_by.return_value.first.return_value = None
# Test data
update_data = {
"name": "new_name",
"external_knowledge_id": "knowledge_id",
"external_knowledge_api_id": "api_id",
}
# Call the method and expect ValueError
with pytest.raises(ValueError) as context:
DatasetService.update_dataset("dataset-123", update_data, mock_user)
assert "External knowledge binding not found" in str(context.value)
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
@patch("services.dataset_service.datetime")
def test_update_internal_dataset_basic_success(
self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db
):
"""
Test successful update of internal dataset with basic fields.
Verifies that:
1. Basic dataset attributes are updated correctly
2. Filtered data excludes None values except description
3. Timestamp fields are updated
4. Database changes are committed
"""
# Create mock dataset
mock_dataset = Mock(spec=Dataset)
mock_dataset.id = "dataset-123"
mock_dataset.provider = "vendor"
mock_dataset.name = "old_name"
mock_dataset.description = "old_description"
mock_dataset.indexing_technique = "high_quality"
mock_dataset.retrieval_model = "old_model"
mock_dataset.embedding_model_provider = "openai"
mock_dataset.embedding_model = "text-embedding-ada-002"
mock_dataset.collection_binding_id = "binding-123"
# Create mock user
mock_user = Mock()
mock_user.id = "user-789"
# Set up mock return values
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
mock_datetime.datetime.now.return_value = current_time
mock_datetime.UTC = datetime.UTC
# Mock dataset retrieval
mock_get_dataset.return_value = mock_dataset
# Test data
update_data = {
"name": "new_name",
"description": "new_description",
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
}
# Call the method
result = DatasetService.update_dataset("dataset-123", update_data, mock_user)
# Verify permission check was called
mock_check_permission.assert_called_once_with(mock_dataset, mock_user)
# Verify database update was called with correct filtered data
expected_filtered_data = {
"name": "new_name",
"description": "new_description",
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"updated_by": mock_user.id,
"updated_at": current_time.replace(tzinfo=None),
}
mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data)
mock_db.commit.assert_called_once()
# Verify return value
assert result == mock_dataset
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.deal_dataset_vector_index_task")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
@patch("services.dataset_service.datetime")
def test_update_internal_dataset_indexing_technique_to_economy(
self, mock_datetime, mock_check_permission, mock_get_dataset, mock_task, mock_db
):
"""
Test updating internal dataset indexing technique to economy.
Verifies that:
1. Embedding model fields are cleared when switching to economy
2. Vector index task is triggered with 'remove' action
"""
# Create mock dataset
mock_dataset = Mock(spec=Dataset)
mock_dataset.id = "dataset-123"
mock_dataset.provider = "vendor"
mock_dataset.indexing_technique = "high_quality"
# Create mock user
mock_user = Mock()
mock_user.id = "user-789"
# Set up mock return values
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
mock_datetime.datetime.now.return_value = current_time
mock_datetime.UTC = datetime.UTC
# Mock dataset retrieval
mock_get_dataset.return_value = mock_dataset
# Test data
update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"}
# Call the method
result = DatasetService.update_dataset("dataset-123", update_data, mock_user)
# Verify database update was called with embedding model fields cleared
expected_filtered_data = {
"indexing_technique": "economy",
"embedding_model": None,
"embedding_model_provider": None,
"collection_binding_id": None,
"retrieval_model": "new_model",
"updated_by": mock_user.id,
"updated_at": current_time.replace(tzinfo=None),
}
mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data)
mock_db.commit.assert_called_once()
# Verify vector index task was triggered
mock_task.delay.assert_called_once_with("dataset-123", "remove")
# Verify return value
assert result == mock_dataset
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
def test_update_dataset_not_found_error(self, mock_check_permission, mock_get_dataset, mock_db):
"""
Test error when dataset is not found.
"""
# Create mock user
mock_user = Mock()
# Mock dataset retrieval returning None
mock_get_dataset.return_value = None
# Test data
update_data = {"name": "new_name"}
# Call the method and expect ValueError
with pytest.raises(ValueError) as context:
DatasetService.update_dataset("dataset-123", update_data, mock_user)
assert "Dataset not found" in str(context.value)
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
def test_update_dataset_permission_error(self, mock_check_permission, mock_get_dataset, mock_db):
"""
Test error when user doesn't have permission.
"""
# Create mock dataset
mock_dataset = Mock(spec=Dataset)
mock_dataset.id = "dataset-123"
# Create mock user
mock_user = Mock()
# Mock dataset retrieval
mock_get_dataset.return_value = mock_dataset
# Mock permission check to raise error
mock_check_permission.side_effect = NoPermissionError("No permission")
# Test data
update_data = {"name": "new_name"}
# Call the method and expect NoPermissionError
with pytest.raises(NoPermissionError):
DatasetService.update_dataset("dataset-123", update_data, mock_user)
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
@patch("services.dataset_service.datetime")
def test_update_internal_dataset_keep_existing_embedding_model(
self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db
):
"""
Test updating internal dataset without changing embedding model.
Verifies that:
1. Existing embedding model settings are preserved when not provided in update
2. No vector index task is triggered
"""
# Create mock dataset
mock_dataset = Mock(spec=Dataset)
mock_dataset.id = "dataset-123"
mock_dataset.provider = "vendor"
mock_dataset.indexing_technique = "high_quality"
mock_dataset.embedding_model_provider = "openai"
mock_dataset.embedding_model = "text-embedding-ada-002"
mock_dataset.collection_binding_id = "binding-123"
# Create mock user
mock_user = Mock()
mock_user.id = "user-789"
# Set up mock return values
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
mock_datetime.datetime.now.return_value = current_time
mock_datetime.UTC = datetime.UTC
# Mock dataset retrieval
mock_get_dataset.return_value = mock_dataset
# Test data without embedding model fields
update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"}
# Call the method
result = DatasetService.update_dataset("dataset-123", update_data, mock_user)
# Verify database update was called with existing embedding model preserved
expected_filtered_data = {
"name": "new_name",
"indexing_technique": "high_quality",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"collection_binding_id": "binding-123",
"retrieval_model": "new_model",
"updated_by": mock_user.id,
"updated_at": current_time.replace(tzinfo=None),
}
mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data)
mock_db.commit.assert_called_once()
# Verify return value
assert result == mock_dataset
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding")
@patch("services.dataset_service.ModelManager")
@patch("services.dataset_service.deal_dataset_vector_index_task")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
@patch("services.dataset_service.datetime")
def test_update_internal_dataset_indexing_technique_to_high_quality(
self,
mock_datetime,
mock_check_permission,
mock_get_dataset,
mock_task,
mock_model_manager,
mock_collection_binding,
mock_db,
):
"""
Test updating internal dataset indexing technique to high_quality.
Verifies that:
1. Embedding model is validated and set
2. Collection binding is retrieved
3. Vector index task is triggered with 'add' action
"""
# Create mock dataset
mock_dataset = Mock(spec=Dataset)
mock_dataset.id = "dataset-123"
mock_dataset.provider = "vendor"
mock_dataset.indexing_technique = "economy"
# Create mock user
mock_user = Mock()
mock_user.id = "user-789"
# Set up mock return values
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
mock_datetime.datetime.now.return_value = current_time
mock_datetime.UTC = datetime.UTC
# Mock dataset retrieval
mock_get_dataset.return_value = mock_dataset
# Mock embedding model
mock_embedding_model = Mock()
mock_embedding_model.model = "text-embedding-ada-002"
mock_embedding_model.provider = "openai"
# Mock collection binding
mock_collection_binding_instance = Mock()
mock_collection_binding_instance.id = "binding-456"
# Mock model manager
mock_model_manager_instance = Mock()
mock_model_manager_instance.get_model_instance.return_value = mock_embedding_model
mock_model_manager.return_value = mock_model_manager_instance
# Mock collection binding service
mock_collection_binding.return_value = mock_collection_binding_instance
# Mock current_user
mock_current_user = Mock()
mock_current_user.current_tenant_id = "tenant-123"
# Test data
update_data = {
"indexing_technique": "high_quality",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"retrieval_model": "new_model",
}
# Call the method with current_user mock
with patch("services.dataset_service.current_user", mock_current_user):
result = DatasetService.update_dataset("dataset-123", update_data, mock_user)
# Verify embedding model was validated
mock_model_manager_instance.get_model_instance.assert_called_once_with(
tenant_id=mock_current_user.current_tenant_id,
provider="openai",
model_type=ModelType.TEXT_EMBEDDING,
model="text-embedding-ada-002",
)
# Verify collection binding was retrieved
mock_collection_binding.assert_called_once_with("openai", "text-embedding-ada-002")
# Verify database update was called with correct data
expected_filtered_data = {
"indexing_technique": "high_quality",
"embedding_model": "text-embedding-ada-002",
"embedding_model_provider": "openai",
"collection_binding_id": "binding-456",
"retrieval_model": "new_model",
"updated_by": mock_user.id,
"updated_at": current_time.replace(tzinfo=None),
}
mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data)
mock_db.commit.assert_called_once()
# Verify vector index task was triggered
mock_task.delay.assert_called_once_with("dataset-123", "add")
# Verify return value
assert result == mock_dataset
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
def test_update_internal_dataset_embedding_model_error(self, mock_check_permission, mock_get_dataset, mock_db):
"""
Test error when embedding model is not available.
"""
# Create mock dataset
mock_dataset = Mock(spec=Dataset)
mock_dataset.id = "dataset-123"
mock_dataset.provider = "vendor"
mock_dataset.indexing_technique = "economy"
# Create mock user
mock_user = Mock()
mock_user.id = "user-789"
# Mock dataset retrieval
mock_get_dataset.return_value = mock_dataset
# Mock current_user
mock_current_user = Mock()
mock_current_user.current_tenant_id = "tenant-123"
# Mock model manager to raise error
with (
patch("services.dataset_service.ModelManager") as mock_model_manager,
patch("services.dataset_service.current_user", mock_current_user),
):
mock_model_manager_instance = Mock()
mock_model_manager_instance.get_model_instance.side_effect = Exception("No Embedding Model available")
mock_model_manager.return_value = mock_model_manager_instance
# Test data
update_data = {
"indexing_technique": "high_quality",
"embedding_model_provider": "invalid_provider",
"embedding_model": "invalid_model",
"retrieval_model": "new_model",
}
# Call the method and expect ValueError
with pytest.raises(Exception) as context:
DatasetService.update_dataset("dataset-123", update_data, mock_user)
assert "No Embedding Model available".lower() in str(context.value).lower()
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
@patch("services.dataset_service.datetime")
def test_update_internal_dataset_filter_none_values(
self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db
):
"""
Test that None values are filtered out except for description field.
"""
# Create mock dataset
mock_dataset = Mock(spec=Dataset)
mock_dataset.id = "dataset-123"
mock_dataset.provider = "vendor"
mock_dataset.indexing_technique = "high_quality"
# Create mock user
mock_user = Mock()
mock_user.id = "user-789"
# Mock dataset retrieval
mock_get_dataset.return_value = mock_dataset
# Test data with None values
update_data = {
"name": "new_name",
"description": None, # Should be included
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"embedding_model_provider": None, # Should be filtered out
"embedding_model": None, # Should be filtered out
}
# Call the method
result = DatasetService.update_dataset("dataset-123", update_data, mock_user)
# Verify database update was called with filtered data
expected_filtered_data = {
"name": "new_name",
"description": None, # Description should be included even if None
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"updated_by": mock_user.id,
"updated_at": mock_db.query.return_value.filter_by.return_value.update.call_args[0][0]["updated_at"],
}
actual_call_args = mock_db.query.return_value.filter_by.return_value.update.call_args[0][0]
# Remove timestamp for comparison as it's dynamic
del actual_call_args["updated_at"]
del expected_filtered_data["updated_at"]
del actual_call_args["collection_binding_id"]
del actual_call_args["embedding_model"]
del actual_call_args["embedding_model_provider"]
assert actual_call_args == expected_filtered_data
# Verify return value
assert result == mock_dataset
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.deal_dataset_vector_index_task")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
@patch("services.dataset_service.datetime")
def test_update_internal_dataset_embedding_model_update(
self, mock_datetime, mock_check_permission, mock_get_dataset, mock_task, mock_db
):
"""
Test updating internal dataset with new embedding model.
Verifies that:
1. Embedding model is updated when different from current
2. Vector index task is triggered with 'update' action
"""
# Create mock dataset
mock_dataset = Mock(spec=Dataset)
mock_dataset.id = "dataset-123"
mock_dataset.provider = "vendor"
mock_dataset.indexing_technique = "high_quality"
mock_dataset.embedding_model_provider = "openai"
mock_dataset.embedding_model = "text-embedding-ada-002"
# Create mock user
mock_user = Mock()
mock_user.id = "user-789"
# Set up mock return values
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
mock_datetime.datetime.now.return_value = current_time
mock_datetime.UTC = datetime.UTC
# Mock dataset retrieval
mock_get_dataset.return_value = mock_dataset
# Mock embedding model
mock_embedding_model = Mock()
mock_embedding_model.model = "text-embedding-3-small"
mock_embedding_model.provider = "openai"
# Mock collection binding
mock_collection_binding_instance = Mock()
mock_collection_binding_instance.id = "binding-789"
# Mock current_user
mock_current_user = Mock()
mock_current_user.current_tenant_id = "tenant-123"
# Mock model manager
with patch("services.dataset_service.ModelManager") as mock_model_manager:
mock_model_manager_instance = Mock()
mock_model_manager_instance.get_model_instance.return_value = mock_embedding_model
mock_model_manager.return_value = mock_model_manager_instance
# Mock collection binding service
with (
patch(
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
) as mock_collection_binding,
patch("services.dataset_service.current_user", mock_current_user),
):
mock_collection_binding.return_value = mock_collection_binding_instance
# Test data
update_data = {
"indexing_technique": "high_quality",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-3-small",
"retrieval_model": "new_model",
}
# Call the method
result = DatasetService.update_dataset("dataset-123", update_data, mock_user)
# Verify embedding model was validated
mock_model_manager_instance.get_model_instance.assert_called_once_with(
tenant_id=mock_current_user.current_tenant_id,
provider="openai",
model_type=ModelType.TEXT_EMBEDDING,
model="text-embedding-3-small",
)
# Verify collection binding was retrieved
mock_collection_binding.assert_called_once_with("openai", "text-embedding-3-small")
# Verify database update was called with correct data
expected_filtered_data = {
"indexing_technique": "high_quality",
"embedding_model": "text-embedding-3-small",
"embedding_model_provider": "openai",
"collection_binding_id": "binding-789",
"retrieval_model": "new_model",
"updated_by": mock_user.id,
"updated_at": current_time.replace(tzinfo=None),
}
mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data)
mock_db.commit.assert_called_once()
# Verify vector index task was triggered
mock_task.delay.assert_called_once_with("dataset-123", "update")
# Verify return value
assert result == mock_dataset
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
@patch("services.dataset_service.datetime")
def test_update_internal_dataset_no_indexing_technique_change(
self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db
):
"""
Test updating internal dataset without changing indexing technique.
Verifies that:
1. No vector index task is triggered when indexing technique doesn't change
2. Database update is performed normally
"""
# Create mock dataset
mock_dataset = Mock(spec=Dataset)
mock_dataset.id = "dataset-123"
mock_dataset.provider = "vendor"
mock_dataset.indexing_technique = "high_quality"
mock_dataset.embedding_model_provider = "openai"
mock_dataset.embedding_model = "text-embedding-ada-002"
mock_dataset.collection_binding_id = "binding-123"
# Create mock user
mock_user = Mock()
mock_user.id = "user-789"
# Set up mock return values
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
mock_datetime.datetime.now.return_value = current_time
mock_datetime.UTC = datetime.UTC
# Mock dataset retrieval
mock_get_dataset.return_value = mock_dataset
# Test data with same indexing technique
update_data = {
"name": "new_name",
"indexing_technique": "high_quality", # Same as current
"retrieval_model": "new_model",
}
# Call the method
result = DatasetService.update_dataset("dataset-123", update_data, mock_user)
# Verify database update was called with correct data
expected_filtered_data = {
"name": "new_name",
"indexing_technique": "high_quality",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"collection_binding_id": "binding-123",
"retrieval_model": "new_model",
"updated_by": mock_user.id,
"updated_at": current_time.replace(tzinfo=None),
}
mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data)
mock_db.commit.assert_called_once()
# Verify no vector index task was triggered
mock_db.query.return_value.filter_by.return_value.update.assert_called_once()
# Verify return value
assert result == mock_dataset