Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -280,174 +280,328 @@ class DatasetService:
|
|||||||
|
|
||||||
@staticmethod
def update_dataset(dataset_id, data, user):
    """
    Update dataset configuration and settings.

    Loads the dataset, checks the caller's permission and then delegates to
    the external or internal update path depending on ``dataset.provider``.

    Args:
        dataset_id: The unique identifier of the dataset to update
        data: Dictionary containing the update data
        user: The user performing the update operation

    Returns:
        Dataset: The updated dataset object

    Raises:
        ValueError: If dataset not found or validation fails
        NoPermissionError: If user lacks permission to update the dataset
            (propagated from ``check_dataset_permission``)
    """
    # Retrieve and validate dataset existence
    dataset = DatasetService.get_dataset(dataset_id)
    if not dataset:
        raise ValueError("Dataset not found")

    # Verify user has permission to update this dataset
    DatasetService.check_dataset_permission(dataset, user)

    # External datasets only carry binding/metadata changes; everything else
    # goes through the internal path (indexing technique, embedding model, ...).
    if dataset.provider == "external":
        return DatasetService._update_external_dataset(dataset, data, user)
    return DatasetService._update_internal_dataset(dataset, data, user)

@staticmethod
def _update_external_dataset(dataset, data, user):
    """
    Update external dataset configuration.

    Args:
        dataset: The dataset object to update
        data: Update data dictionary
        user: User performing the update

    Returns:
        Dataset: Updated dataset object

    Raises:
        ValueError: If ``external_knowledge_id`` or ``external_knowledge_api_id``
            is missing/empty in ``data``, or if no external knowledge binding
            exists for the dataset
    """
    # Validate the required identifiers before touching the dataset so that a
    # rejected request does not leave a half-modified object in the session.
    external_knowledge_id = data.get("external_knowledge_id", None)
    external_knowledge_api_id = data.get("external_knowledge_api_id", None)
    if not external_knowledge_id:
        raise ValueError("External knowledge id is required.")
    if not external_knowledge_api_id:
        raise ValueError("External knowledge api id is required.")

    # Update retrieval model if provided
    external_retrieval_model = data.get("external_retrieval_model", None)
    if external_retrieval_model:
        dataset.retrieval_model = external_retrieval_model

    # Update basic dataset properties; absent keys keep the current value
    dataset.name = data.get("name", dataset.name)
    dataset.description = data.get("description", dataset.description)

    # Update permission if provided
    permission = data.get("permission")
    if permission:
        dataset.permission = permission

    # Update metadata fields. The timestamp is a naive UTC value built the same
    # way as in the internal update path (utcnow() is deprecated).
    dataset.updated_by = user.id if user else None
    dataset.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
    db.session.add(dataset)

    # Update external knowledge binding
    DatasetService._update_external_knowledge_binding(dataset.id, external_knowledge_id, external_knowledge_api_id)

    # Commit changes to database
    db.session.commit()

    return dataset
|
||||||
|
|
||||||
|
@staticmethod
def _update_external_knowledge_binding(dataset_id, external_knowledge_id, external_knowledge_api_id):
    """
    Point the dataset's external knowledge binding at the given ids.

    Args:
        dataset_id: Dataset identifier used to look the binding up
        external_knowledge_id: External knowledge identifier to store
        external_knowledge_api_id: External knowledge API identifier to store

    Raises:
        ValueError: If no binding row exists for ``dataset_id``
    """
    # The lookup runs in its own short-lived session.
    # NOTE(review): source indentation was lost; the mutation below is assumed
    # to happen after this session has been closed - confirm.
    with Session(db.engine) as session:
        binding = session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
        if not binding:
            raise ValueError("External knowledge binding not found.")

    # Nothing to write when both identifiers already match.
    if (
        binding.external_knowledge_id == external_knowledge_id
        and binding.external_knowledge_api_id == external_knowledge_api_id
    ):
        return

    binding.external_knowledge_id = external_knowledge_id
    binding.external_knowledge_api_id = external_knowledge_api_id
    # Staged on the shared session; the caller is responsible for committing.
    db.session.add(binding)
|
||||||
|
|
||||||
|
@staticmethod
def _update_internal_dataset(dataset, data, user):
    """
    Update internal dataset configuration.

    Args:
        dataset: The dataset object to update
        data: Update data dictionary (modified in place: several keys are popped)
        user: User performing the update

    Returns:
        Dataset: The ``dataset`` object that was passed in

    Raises:
        KeyError: If ``data`` has no ``"retrieval_model"`` entry
    """
    # Strip keys that must not be written through to the Dataset row.
    for dropped_key in (
        "partial_member_list",
        "external_knowledge_api_id",
        "external_knowledge_id",
        "external_retrieval_model",
    ):
        data.pop(dropped_key, None)

    # Keep only provided values; "description" is kept even when it is None.
    filtered_data = {}
    for key, value in data.items():
        if key == "description" or value is not None:
            filtered_data[key] = value

    # May add/remove embedding-related keys in filtered_data and tells us
    # whether the vector index has to be touched afterwards.
    action = DatasetService._handle_indexing_technique_change(dataset, data, filtered_data)

    # Audit fields: naive UTC timestamp.
    filtered_data["updated_by"] = user.id
    aware_now = datetime.datetime.now(datetime.UTC)
    filtered_data["updated_at"] = aware_now.replace(tzinfo=None)

    # Retrieval model is always taken from the request.
    filtered_data["retrieval_model"] = data["retrieval_model"]

    # Persist via a bulk UPDATE keyed on the dataset id.
    db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
    db.session.commit()

    # Kick off the asynchronous index job only when an action was requested.
    if action:
        deal_dataset_vector_index_task.delay(dataset.id, action)

    return dataset
|
||||||
|
|
||||||
|
@staticmethod
def _handle_indexing_technique_change(dataset, data, filtered_data):
    """
    Decide what to do with embedding settings based on the requested indexing technique.

    Args:
        dataset: Current dataset object
        data: Update data dictionary; must contain ``"indexing_technique"``
        filtered_data: Filtered update data, modified in place

    Returns:
        str: Action to perform ('add', 'remove', 'update', or None)
    """
    requested_technique = data["indexing_technique"]

    # Same technique as before: only the embedding model itself may have changed.
    if dataset.indexing_technique == requested_technique:
        return DatasetService._handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data)

    if requested_technique == "economy":
        # Economy mode carries no embedding configuration; clear all three fields.
        filtered_data["embedding_model"] = None
        filtered_data["embedding_model_provider"] = None
        filtered_data["collection_binding_id"] = None
        return "remove"

    if requested_technique == "high_quality":
        # High quality mode needs a resolved embedding model and collection binding.
        DatasetService._configure_embedding_model_for_high_quality(data, filtered_data)
        return "add"

    # Technique changed to a value this method does not handle: no index action.
    return None
|
||||||
|
|
||||||
|
@staticmethod
def _configure_embedding_model_for_high_quality(data, filtered_data):
    """
    Resolve the requested embedding model and record it for high quality indexing.

    Writes ``embedding_model``, ``embedding_model_provider`` and
    ``collection_binding_id`` into ``filtered_data``.

    Args:
        data: Update data dictionary; read for ``"embedding_model_provider"``
            and ``"embedding_model"``
        filtered_data: Filtered update data to modify

    Raises:
        ValueError: If resolving the model raises ``LLMBadRequestError`` or
            ``ProviderTokenNotInitError``
    """
    try:
        resolved_model = ModelManager().get_model_instance(
            tenant_id=current_user.current_tenant_id,
            provider=data["embedding_model_provider"],
            model_type=ModelType.TEXT_EMBEDDING,
            model=data["embedding_model"],
        )
        # Store the values reported by the resolved model rather than the raw request values.
        filtered_data["embedding_model"] = resolved_model.model
        filtered_data["embedding_model_provider"] = resolved_model.provider

        collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
            resolved_model.provider, resolved_model.model
        )
        filtered_data["collection_binding_id"] = collection_binding.id
    except LLMBadRequestError:
        message = "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
        raise ValueError(message)
    except ProviderTokenNotInitError as ex:
        # Surface the provider's own description to the caller.
        raise ValueError(ex.description)
|
||||||
|
|
||||||
|
@staticmethod
def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data):
    """
    Handle embedding model updates when indexing technique remains the same.

    Args:
        dataset: Current dataset object
        data: Update data dictionary
        filtered_data: Filtered update data to modify

    Returns:
        str: Action to perform ('update' or None)
    """
    # A missing key and an empty value are treated alike: .get() yields a
    # falsy result in both cases.
    requested_provider = data.get("embedding_model_provider")
    requested_model = data.get("embedding_model")

    if requested_provider and requested_model:
        return DatasetService._update_embedding_model_settings(dataset, data, filtered_data)

    # Request does not name a complete embedding model: keep what the dataset has.
    DatasetService._preserve_existing_embedding_settings(dataset, filtered_data)
    return None
|
||||||
|
|
||||||
|
@staticmethod
def _preserve_existing_embedding_settings(dataset, filtered_data):
    """
    Carry the dataset's current embedding settings into the update payload.

    Args:
        dataset: Current dataset object
        filtered_data: Filtered update data to modify
    """
    has_existing_model = dataset.embedding_model_provider and dataset.embedding_model
    if has_existing_model:
        # Re-assert the stored provider/model so the update does not change them.
        filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
        filtered_data["embedding_model"] = dataset.embedding_model
        # The collection binding is only carried over when one is set.
        if dataset.collection_binding_id:
            filtered_data["collection_binding_id"] = dataset.collection_binding_id

    # Drop empty/None embedding fields so they are not written at all.
    # NOTE(review): source indentation was lost; this cleanup is assumed to run
    # unconditionally (it is a no-op when the values above were just set).
    for field in ("embedding_model_provider", "embedding_model"):
        if field in filtered_data and not filtered_data[field]:
            del filtered_data[field]
|
||||||
|
|
||||||
|
@staticmethod
def _update_embedding_model_settings(dataset, data, filtered_data):
    """
    Update embedding model settings when the request names a different model.

    Args:
        dataset: Current dataset object
        data: Update data dictionary; read for ``"embedding_model_provider"``
            and ``"embedding_model"``
        filtered_data: Filtered update data to modify

    Returns:
        str: Action to perform ('update' or None)

    Raises:
        ValueError: If ``LLMBadRequestError`` or ``ProviderTokenNotInitError``
            escapes while the new settings are applied
    """

    def _canonical_provider(raw_provider):
        # Providers are compared in their ModelProviderID string form; empty -> None.
        return str(ModelProviderID(raw_provider)) if raw_provider else None

    try:
        same_provider = _canonical_provider(dataset.embedding_model_provider) == _canonical_provider(
            data["embedding_model_provider"]
        )
        if same_provider and data["embedding_model"] == dataset.embedding_model:
            # Neither provider nor model differs: nothing to re-index.
            return None

        DatasetService._apply_new_embedding_settings(dataset, data, filtered_data)
        # NOTE(review): 'update' is returned even if the helper fell back to the
        # existing settings - confirm that re-indexing is intended in that case.
        return "update"
    except LLMBadRequestError:
        message = "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
        raise ValueError(message)
    except ProviderTokenNotInitError as ex:
        raise ValueError(ex.description)
|
||||||
|
|
||||||
|
@staticmethod
def _apply_new_embedding_settings(dataset, data, filtered_data):
    """
    Resolve the requested embedding model and write it into the update payload.

    If the provider raises ``ProviderTokenNotInitError`` while the model is
    resolved, the dataset's current settings (when present) are written
    instead and the function returns without applying the requested model.

    Args:
        dataset: Current dataset object
        data: Update data dictionary; read for ``"embedding_model_provider"``
            and ``"embedding_model"``
        filtered_data: Filtered update data to modify
    """
    manager = ModelManager()
    try:
        resolved_model = manager.get_model_instance(
            tenant_id=current_user.current_tenant_id,
            provider=data["embedding_model_provider"],
            model_type=ModelType.TEXT_EMBEDDING,
            model=data["embedding_model"],
        )
    except ProviderTokenNotInitError:
        # The requested model cannot be resolved: fall back to what is stored.
        if dataset.embedding_model_provider and dataset.embedding_model:
            filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
            filtered_data["embedding_model"] = dataset.embedding_model
            if dataset.collection_binding_id:
                filtered_data["collection_binding_id"] = dataset.collection_binding_id
    else:
        # Record the values reported by the resolved model and its collection binding.
        filtered_data["embedding_model"] = resolved_model.model
        filtered_data["embedding_model_provider"] = resolved_model.provider
        collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
            resolved_model.provider, resolved_model.model
        )
        filtered_data["collection_binding_id"] = collection_binding.id
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_dataset(dataset_id, user):
|
def delete_dataset(dataset_id, user):
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
|
@@ -0,0 +1,826 @@
|
|||||||
|
import datetime
|
||||||
|
|
||||||
|
# Mock redis_client before importing dataset_service
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
|
from models.dataset import Dataset, ExternalKnowledgeBindings
|
||||||
|
from services.dataset_service import DatasetService
|
||||||
|
from services.errors.account import NoPermissionError
|
||||||
|
from tests.unit_tests.conftest import redis_mock
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatasetServiceUpdateDataset:
|
||||||
|
"""
|
||||||
|
Comprehensive unit tests for DatasetService.update_dataset method.
|
||||||
|
|
||||||
|
This test suite covers all supported scenarios including:
|
||||||
|
- External dataset updates
|
||||||
|
- Internal dataset updates with different indexing techniques
|
||||||
|
- Embedding model updates
|
||||||
|
- Permission checks
|
||||||
|
- Error conditions and edge cases
|
||||||
|
"""
|
||||||
|
|
||||||
|
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
@patch("services.dataset_service.datetime")
def test_update_external_dataset_success(self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db):
    """
    Test successful update of external dataset.

    Verifies that:
    1. External dataset attributes (including permission and updated_by) are updated
    2. External knowledge binding is updated when values change
    3. Database changes are committed
    4. Permission check is performed
    """
    # Imported locally because the module header does not import ``db``.
    from extensions.ext_database import db

    # Replace the engine so that Session(db.engine) never touches a real database.
    with patch.object(db.__class__, "engine", new_callable=Mock):
        # Create mock dataset
        mock_dataset = Mock(spec=Dataset)
        mock_dataset.id = "dataset-123"
        mock_dataset.provider = "external"
        mock_dataset.name = "old_name"
        mock_dataset.description = "old_description"
        mock_dataset.retrieval_model = "old_model"

        # Create mock user
        mock_user = Mock()
        mock_user.id = "user-789"

        # Create mock external knowledge binding with values that differ from the request
        mock_binding = Mock(spec=ExternalKnowledgeBindings)
        mock_binding.external_knowledge_id = "old_knowledge_id"
        mock_binding.external_knowledge_api_id = "old_api_id"

        # Freeze the clock seen by the service module
        current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
        mock_datetime.datetime.now.return_value = current_time
        mock_datetime.UTC = datetime.UTC

        # Mock dataset retrieval
        mock_get_dataset.return_value = mock_dataset

        # Mock external knowledge binding query
        with patch("services.dataset_service.Session") as mock_session:
            mock_session_instance = Mock()
            mock_session.return_value.__enter__.return_value = mock_session_instance
            mock_session_instance.query.return_value.filter_by.return_value.first.return_value = mock_binding

            # Test data
            update_data = {
                "name": "new_name",
                "description": "new_description",
                "external_retrieval_model": "new_model",
                "permission": "only_me",
                "external_knowledge_id": "new_knowledge_id",
                "external_knowledge_api_id": "new_api_id",
            }

            # Call the method
            result = DatasetService.update_dataset("dataset-123", update_data, mock_user)

            # Verify permission check was called
            mock_check_permission.assert_called_once_with(mock_dataset, mock_user)

            # Verify dataset attributes were updated
            assert mock_dataset.name == "new_name"
            assert mock_dataset.description == "new_description"
            assert mock_dataset.retrieval_model == "new_model"
            assert mock_dataset.permission == "only_me"
            assert mock_dataset.updated_by == "user-789"

            # Verify external knowledge binding was updated
            assert mock_binding.external_knowledge_id == "new_knowledge_id"
            assert mock_binding.external_knowledge_api_id == "new_api_id"

            # Verify database operations
            mock_db.add.assert_any_call(mock_dataset)
            mock_db.add.assert_any_call(mock_binding)
            mock_db.commit.assert_called_once()

            # Verify return value
            assert result == mock_dataset
|
||||||
|
|
||||||
|
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
def test_update_external_dataset_missing_knowledge_id_error(self, mock_check_permission, mock_get_dataset, mock_db):
    """
    An external dataset update without ``external_knowledge_id`` must raise ValueError.
    """
    # External dataset returned by the patched lookup
    external_dataset = Mock(spec=Dataset)
    external_dataset.provider = "external"
    mock_get_dataset.return_value = external_dataset

    acting_user = Mock()

    # Payload deliberately omits external_knowledge_id
    payload = {"name": "new_name", "external_knowledge_api_id": "api_id"}

    with pytest.raises(ValueError, match="External knowledge id is required"):
        DatasetService.update_dataset("dataset-123", payload, acting_user)
|
||||||
|
|
||||||
|
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
def test_update_external_dataset_missing_api_id_error(self, mock_check_permission, mock_get_dataset, mock_db):
    """
    An external dataset update without ``external_knowledge_api_id`` must raise ValueError.
    """
    # External dataset returned by the patched lookup
    external_dataset = Mock(spec=Dataset)
    external_dataset.provider = "external"
    mock_get_dataset.return_value = external_dataset

    acting_user = Mock()

    # Payload deliberately omits external_knowledge_api_id
    payload = {"name": "new_name", "external_knowledge_id": "knowledge_id"}

    with pytest.raises(ValueError, match="External knowledge api id is required"):
        DatasetService.update_dataset("dataset-123", payload, acting_user)
|
||||||
|
|
||||||
|
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.Session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
def test_update_external_dataset_binding_not_found_error(
    self, mock_check_permission, mock_get_dataset, mock_session, mock_db
):
    """
    An external dataset update must raise ValueError when no binding row is found.
    """
    # Imported locally because the module header does not import ``db``.
    from extensions.ext_database import db

    # External dataset returned by the patched lookup
    external_dataset = Mock(spec=Dataset)
    external_dataset.provider = "external"
    mock_get_dataset.return_value = external_dataset

    acting_user = Mock()

    # The binding query inside the patched Session yields nothing
    session_in_context = Mock()
    mock_session.return_value.__enter__.return_value = session_in_context
    session_in_context.query.return_value.filter_by.return_value.first.return_value = None

    payload = {
        "name": "new_name",
        "external_knowledge_id": "knowledge_id",
        "external_knowledge_api_id": "api_id",
    }

    # Replace the engine so that Session(db.engine) never touches a real database.
    with patch.object(db.__class__, "engine", new_callable=Mock):
        with pytest.raises(ValueError, match="External knowledge binding not found"):
            DatasetService.update_dataset("dataset-123", payload, acting_user)
|
||||||
|
|
||||||
|
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
@patch("services.dataset_service.datetime")
def test_update_internal_dataset_basic_success(
    self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db
):
    """
    Test a plain update of an internal (non-external) dataset.

    Checks that:
    1. The permission check receives the dataset and the user
    2. The query update receives the submitted fields plus updated_by / updated_at
    3. The session is committed exactly once
    4. The dataset returned by get_dataset is what update_dataset returns
    """
    # Freeze the clock seen through the patched datetime module
    frozen_now = datetime.datetime(2023, 1, 1, 12, 0, 0)
    mock_datetime.datetime.now.return_value = frozen_now
    mock_datetime.UTC = datetime.UTC

    # Existing dataset state before the update
    existing_dataset = Mock(spec=Dataset)
    initial_state = {
        "id": "dataset-123",
        "provider": "vendor",
        "name": "old_name",
        "description": "old_description",
        "indexing_technique": "high_quality",
        "retrieval_model": "old_model",
        "embedding_model_provider": "openai",
        "embedding_model": "text-embedding-ada-002",
        "collection_binding_id": "binding-123",
    }
    for attr_name, attr_value in initial_state.items():
        setattr(existing_dataset, attr_name, attr_value)
    mock_get_dataset.return_value = existing_dataset

    acting_user = Mock()
    acting_user.id = "user-789"

    payload = {
        "name": "new_name",
        "description": "new_description",
        "indexing_technique": "high_quality",
        "retrieval_model": "new_model",
        "embedding_model_provider": "openai",
        "embedding_model": "text-embedding-ada-002",
    }

    outcome = DatasetService.update_dataset("dataset-123", payload, acting_user)

    mock_check_permission.assert_called_once_with(existing_dataset, acting_user)

    # The persisted values are the payload plus the audit fields
    expected_update = {
        **payload,
        "updated_by": acting_user.id,
        "updated_at": frozen_now.replace(tzinfo=None),
    }
    update_call = mock_db.query.return_value.filter_by.return_value.update
    update_call.assert_called_once_with(expected_update)
    mock_db.commit.assert_called_once()

    assert outcome == existing_dataset
|
||||||
|
|
||||||
|
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.deal_dataset_vector_index_task")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
@patch("services.dataset_service.datetime")
def test_update_internal_dataset_indexing_technique_to_economy(
    self, mock_datetime, mock_check_permission, mock_get_dataset, mock_task, mock_db
):
    """
    Test switching an internal dataset from high_quality to economy indexing.

    Checks that:
    1. The update sets embedding_model, embedding_model_provider and
       collection_binding_id to None
    2. The vector index task is queued with the 'remove' action
    """
    # Freeze the clock seen through the patched datetime module
    frozen_now = datetime.datetime(2023, 1, 1, 12, 0, 0)
    mock_datetime.datetime.now.return_value = frozen_now
    mock_datetime.UTC = datetime.UTC

    # Dataset currently indexed as high_quality
    existing_dataset = Mock(spec=Dataset)
    existing_dataset.id = "dataset-123"
    existing_dataset.provider = "vendor"
    existing_dataset.indexing_technique = "high_quality"
    mock_get_dataset.return_value = existing_dataset

    acting_user = Mock()
    acting_user.id = "user-789"

    payload = {"indexing_technique": "economy", "retrieval_model": "new_model"}

    outcome = DatasetService.update_dataset("dataset-123", payload, acting_user)

    # Embedding-related columns are expected to be blanked out
    expected_update = dict.fromkeys(
        ("embedding_model", "embedding_model_provider", "collection_binding_id")
    )
    expected_update.update(
        {
            "indexing_technique": "economy",
            "retrieval_model": "new_model",
            "updated_by": acting_user.id,
            "updated_at": frozen_now.replace(tzinfo=None),
        }
    )
    update_call = mock_db.query.return_value.filter_by.return_value.update
    update_call.assert_called_once_with(expected_update)
    mock_db.commit.assert_called_once()

    # The index-maintenance task must be queued to remove the vector index
    mock_task.delay.assert_called_once_with("dataset-123", "remove")

    assert outcome == existing_dataset
|
||||||
|
|
||||||
|
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
def test_update_dataset_not_found_error(self, mock_check_permission, mock_get_dataset, mock_db):
    """
    Test that updating a dataset that cannot be retrieved raises ValueError.
    """
    # get_dataset finds nothing for the requested id
    mock_get_dataset.return_value = None

    acting_user = Mock()
    payload = {"name": "new_name"}

    with pytest.raises(ValueError) as exc_info:
        DatasetService.update_dataset("dataset-123", payload, acting_user)

    assert "Dataset not found" in str(exc_info.value)
|
||||||
|
|
||||||
|
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
def test_update_dataset_permission_error(self, mock_check_permission, mock_get_dataset, mock_db):
    """
    Test that a NoPermissionError from the permission check propagates to the caller.
    """
    # The dataset exists ...
    guarded_dataset = Mock(spec=Dataset)
    guarded_dataset.id = "dataset-123"
    mock_get_dataset.return_value = guarded_dataset

    # ... but the permission check rejects the user
    mock_check_permission.side_effect = NoPermissionError("No permission")

    acting_user = Mock()
    payload = {"name": "new_name"}

    with pytest.raises(NoPermissionError):
        DatasetService.update_dataset("dataset-123", payload, acting_user)
|
||||||
|
|
||||||
|
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
@patch("services.dataset_service.datetime")
def test_update_internal_dataset_keep_existing_embedding_model(
    self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db
):
    """
    Test updating internal dataset without changing embedding model.

    Verifies that:
    1. Existing embedding model settings are preserved when not provided in update
    2. No vector index task is triggered
    """
    # Create mock dataset
    mock_dataset = Mock(spec=Dataset)
    mock_dataset.id = "dataset-123"
    mock_dataset.provider = "vendor"
    mock_dataset.indexing_technique = "high_quality"
    mock_dataset.embedding_model_provider = "openai"
    mock_dataset.embedding_model = "text-embedding-ada-002"
    mock_dataset.collection_binding_id = "binding-123"

    # Create mock user
    mock_user = Mock()
    mock_user.id = "user-789"

    # Freeze the clock seen through the patched datetime module
    current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
    mock_datetime.datetime.now.return_value = current_time
    mock_datetime.UTC = datetime.UTC

    # Mock dataset retrieval
    mock_get_dataset.return_value = mock_dataset

    # Test data without embedding model fields
    update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"}

    # Patch the vector index task so that point 2 of the docstring is actually
    # checked (the original test never looked at the task at all)
    with patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task:
        result = DatasetService.update_dataset("dataset-123", update_data, mock_user)

    # Verify database update was called with existing embedding model preserved
    expected_filtered_data = {
        "name": "new_name",
        "indexing_technique": "high_quality",
        "embedding_model_provider": "openai",
        "embedding_model": "text-embedding-ada-002",
        "collection_binding_id": "binding-123",
        "retrieval_model": "new_model",
        "updated_by": mock_user.id,
        "updated_at": current_time.replace(tzinfo=None),
    }

    mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data)
    mock_db.commit.assert_called_once()

    # Verify no vector index task was triggered
    mock_task.delay.assert_not_called()

    # Verify return value
    assert result == mock_dataset
|
||||||
|
|
||||||
|
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding")
@patch("services.dataset_service.ModelManager")
@patch("services.dataset_service.deal_dataset_vector_index_task")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
@patch("services.dataset_service.datetime")
def test_update_internal_dataset_indexing_technique_to_high_quality(
    self,
    mock_datetime,
    mock_check_permission,
    mock_get_dataset,
    mock_task,
    mock_model_manager,
    mock_collection_binding,
    mock_db,
):
    """
    Test switching an internal dataset from economy to high_quality indexing.

    Checks that:
    1. The embedding model is looked up through ModelManager for the current tenant
    2. The collection binding is fetched for that provider/model pair
    3. The update carries the model, provider and binding id
    4. The vector index task is queued with the 'add' action
    """
    # Freeze the clock seen through the patched datetime module
    frozen_now = datetime.datetime(2023, 1, 1, 12, 0, 0)
    mock_datetime.datetime.now.return_value = frozen_now
    mock_datetime.UTC = datetime.UTC

    # Dataset currently indexed as economy
    existing_dataset = Mock(spec=Dataset)
    existing_dataset.id = "dataset-123"
    existing_dataset.provider = "vendor"
    existing_dataset.indexing_technique = "economy"
    mock_get_dataset.return_value = existing_dataset

    acting_user = Mock()
    acting_user.id = "user-789"

    # ModelManager() returns a manager whose lookup yields this embedding model
    embedding_stub = Mock()
    embedding_stub.model = "text-embedding-ada-002"
    embedding_stub.provider = "openai"
    manager_stub = Mock()
    manager_stub.get_model_instance.return_value = embedding_stub
    mock_model_manager.return_value = manager_stub

    # Collection binding handed back by the binding service
    binding_stub = Mock()
    binding_stub.id = "binding-456"
    mock_collection_binding.return_value = binding_stub

    # Stand-in for the module-level current_user
    tenant_user = Mock()
    tenant_user.current_tenant_id = "tenant-123"

    payload = {
        "indexing_technique": "high_quality",
        "embedding_model_provider": "openai",
        "embedding_model": "text-embedding-ada-002",
        "retrieval_model": "new_model",
    }

    with patch("services.dataset_service.current_user", tenant_user):
        outcome = DatasetService.update_dataset("dataset-123", payload, acting_user)

    # Embedding model lookup
    manager_stub.get_model_instance.assert_called_once_with(
        tenant_id=tenant_user.current_tenant_id,
        provider="openai",
        model_type=ModelType.TEXT_EMBEDDING,
        model="text-embedding-ada-002",
    )

    # Collection binding lookup
    mock_collection_binding.assert_called_once_with("openai", "text-embedding-ada-002")

    # Persisted values
    expected_update = {
        "indexing_technique": "high_quality",
        "embedding_model": "text-embedding-ada-002",
        "embedding_model_provider": "openai",
        "collection_binding_id": "binding-456",
        "retrieval_model": "new_model",
        "updated_by": acting_user.id,
        "updated_at": frozen_now.replace(tzinfo=None),
    }
    update_call = mock_db.query.return_value.filter_by.return_value.update
    update_call.assert_called_once_with(expected_update)
    mock_db.commit.assert_called_once()

    # The index-maintenance task must be queued to add the vector index
    mock_task.delay.assert_called_once_with("dataset-123", "add")

    assert outcome == existing_dataset
|
||||||
|
|
||||||
|
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
def test_update_internal_dataset_embedding_model_error(self, mock_check_permission, mock_get_dataset, mock_db):
    """
    Test that a failure while looking up the embedding model surfaces to the caller.

    ModelManager().get_model_instance is made to raise; the update call must
    raise an exception whose message contains "No Embedding Model available"
    (compared case-insensitively).
    """
    # Dataset currently indexed as economy, so high_quality is a technique change
    economy_dataset = Mock(spec=Dataset)
    economy_dataset.id = "dataset-123"
    economy_dataset.provider = "vendor"
    economy_dataset.indexing_technique = "economy"
    mock_get_dataset.return_value = economy_dataset

    acting_user = Mock()
    acting_user.id = "user-789"

    # Stand-in for the module-level current_user
    tenant_user = Mock()
    tenant_user.current_tenant_id = "tenant-123"

    payload = {
        "indexing_technique": "high_quality",
        "embedding_model_provider": "invalid_provider",
        "embedding_model": "invalid_model",
        "retrieval_model": "new_model",
    }

    with (
        patch("services.dataset_service.ModelManager") as mock_model_manager,
        patch("services.dataset_service.current_user", tenant_user),
    ):
        # Every model lookup fails
        failing_manager = Mock()
        failing_manager.get_model_instance.side_effect = Exception("No Embedding Model available")
        mock_model_manager.return_value = failing_manager

        # Any Exception subclass is accepted here; only the message is pinned below
        with pytest.raises(Exception) as exc_info:
            DatasetService.update_dataset("dataset-123", payload, acting_user)

    assert "No Embedding Model available".lower() in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
@patch("services.dataset_service.datetime")
def test_update_internal_dataset_filter_none_values(
    self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db
):
    """
    Test that None values are filtered out except for description field.

    The payload carries None for description, embedding_model_provider and
    embedding_model. The recorded update must keep description=None, while
    the remaining compared fields must equal the non-None payload values plus
    updated_by.
    """
    # Create mock dataset
    mock_dataset = Mock(spec=Dataset)
    mock_dataset.id = "dataset-123"
    mock_dataset.provider = "vendor"
    mock_dataset.indexing_technique = "high_quality"

    # Create mock user
    mock_user = Mock()
    mock_user.id = "user-789"

    # Mock dataset retrieval
    mock_get_dataset.return_value = mock_dataset

    # Test data with None values
    update_data = {
        "name": "new_name",
        "description": None,  # Should be included
        "indexing_technique": "high_quality",
        "retrieval_model": "new_model",
        "embedding_model_provider": None,  # Should be filtered out
        "embedding_model": None,  # Should be filtered out
    }

    # Call the method
    result = DatasetService.update_dataset("dataset-123", update_data, mock_user)

    # Fields whose values are pinned by this test
    expected_filtered_data = {
        "name": "new_name",
        "description": None,  # Description should be included even if None
        "indexing_technique": "high_quality",
        "retrieval_model": "new_model",
        "updated_by": mock_user.id,
    }

    # Keys left out of the comparison:
    # - updated_at: the datetime module is patched here without a configured
    #   return value, so the timestamp is not a predictable value.
    # - collection_binding_id / embedding_model / embedding_model_provider:
    #   the original test also dropped these before comparing; presumably the
    #   service fills them from the existing dataset -- confirm against
    #   DatasetService.
    ignored_keys = {"updated_at", "collection_binding_id", "embedding_model", "embedding_model_provider"}

    # Compare against a filtered copy so the mock's recorded call args stay intact
    recorded_update = mock_db.query.return_value.filter_by.return_value.update.call_args[0][0]
    comparable_update = {key: value for key, value in recorded_update.items() if key not in ignored_keys}

    assert comparable_update == expected_filtered_data

    # Verify return value
    assert result == mock_dataset
|
||||||
|
|
||||||
|
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.deal_dataset_vector_index_task")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
@patch("services.dataset_service.datetime")
def test_update_internal_dataset_embedding_model_update(
    self, mock_datetime, mock_check_permission, mock_get_dataset, mock_task, mock_db
):
    """
    Test changing the embedding model of a high_quality internal dataset.

    Checks that:
    1. The new embedding model is looked up through ModelManager for the current tenant
    2. The collection binding is fetched for the new provider/model pair
    3. The update carries the new model, provider and binding id
    4. The vector index task is queued with the 'update' action
    """
    # Freeze the clock seen through the patched datetime module
    frozen_now = datetime.datetime(2023, 1, 1, 12, 0, 0)
    mock_datetime.datetime.now.return_value = frozen_now
    mock_datetime.UTC = datetime.UTC

    # Dataset currently using text-embedding-ada-002
    existing_dataset = Mock(spec=Dataset)
    existing_dataset.id = "dataset-123"
    existing_dataset.provider = "vendor"
    existing_dataset.indexing_technique = "high_quality"
    existing_dataset.embedding_model_provider = "openai"
    existing_dataset.embedding_model = "text-embedding-ada-002"
    mock_get_dataset.return_value = existing_dataset

    acting_user = Mock()
    acting_user.id = "user-789"

    # Embedding model the manager lookup will yield
    embedding_stub = Mock()
    embedding_stub.model = "text-embedding-3-small"
    embedding_stub.provider = "openai"

    # Collection binding handed back by the binding service
    binding_stub = Mock()
    binding_stub.id = "binding-789"

    # Stand-in for the module-level current_user
    tenant_user = Mock()
    tenant_user.current_tenant_id = "tenant-123"

    payload = {
        "indexing_technique": "high_quality",
        "embedding_model_provider": "openai",
        "embedding_model": "text-embedding-3-small",
        "retrieval_model": "new_model",
    }

    with (
        patch("services.dataset_service.ModelManager") as mock_model_manager,
        patch(
            "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
        ) as mock_collection_binding,
        patch("services.dataset_service.current_user", tenant_user),
    ):
        manager_stub = Mock()
        manager_stub.get_model_instance.return_value = embedding_stub
        mock_model_manager.return_value = manager_stub
        mock_collection_binding.return_value = binding_stub

        outcome = DatasetService.update_dataset("dataset-123", payload, acting_user)

    # Embedding model lookup
    manager_stub.get_model_instance.assert_called_once_with(
        tenant_id=tenant_user.current_tenant_id,
        provider="openai",
        model_type=ModelType.TEXT_EMBEDDING,
        model="text-embedding-3-small",
    )

    # Collection binding lookup
    mock_collection_binding.assert_called_once_with("openai", "text-embedding-3-small")

    # Persisted values
    expected_update = {
        "indexing_technique": "high_quality",
        "embedding_model": "text-embedding-3-small",
        "embedding_model_provider": "openai",
        "collection_binding_id": "binding-789",
        "retrieval_model": "new_model",
        "updated_by": acting_user.id,
        "updated_at": frozen_now.replace(tzinfo=None),
    }
    update_call = mock_db.query.return_value.filter_by.return_value.update
    update_call.assert_called_once_with(expected_update)
    mock_db.commit.assert_called_once()

    # The index-maintenance task must be queued to update the vector index
    mock_task.delay.assert_called_once_with("dataset-123", "update")

    assert outcome == existing_dataset
|
||||||
|
|
||||||
|
@patch("extensions.ext_database.db.session")
@patch("services.dataset_service.DatasetService.get_dataset")
@patch("services.dataset_service.DatasetService.check_dataset_permission")
@patch("services.dataset_service.datetime")
def test_update_internal_dataset_no_indexing_technique_change(
    self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db
):
    """
    Test updating internal dataset without changing indexing technique.

    Verifies that:
    1. No vector index task is triggered when indexing technique doesn't change
    2. Database update is performed normally
    """
    # Create mock dataset
    mock_dataset = Mock(spec=Dataset)
    mock_dataset.id = "dataset-123"
    mock_dataset.provider = "vendor"
    mock_dataset.indexing_technique = "high_quality"
    mock_dataset.embedding_model_provider = "openai"
    mock_dataset.embedding_model = "text-embedding-ada-002"
    mock_dataset.collection_binding_id = "binding-123"

    # Create mock user
    mock_user = Mock()
    mock_user.id = "user-789"

    # Freeze the clock seen through the patched datetime module
    current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
    mock_datetime.datetime.now.return_value = current_time
    mock_datetime.UTC = datetime.UTC

    # Mock dataset retrieval
    mock_get_dataset.return_value = mock_dataset

    # Test data with same indexing technique
    update_data = {
        "name": "new_name",
        "indexing_technique": "high_quality",  # Same as current
        "retrieval_model": "new_model",
    }

    # Patch the vector index task so its (non-)invocation can be observed; the
    # original test repeated the DB update assertion here instead of checking it
    with patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task:
        result = DatasetService.update_dataset("dataset-123", update_data, mock_user)

    # Verify database update was called with correct data
    expected_filtered_data = {
        "name": "new_name",
        "indexing_technique": "high_quality",
        "embedding_model_provider": "openai",
        "embedding_model": "text-embedding-ada-002",
        "collection_binding_id": "binding-123",
        "retrieval_model": "new_model",
        "updated_by": mock_user.id,
        "updated_at": current_time.replace(tzinfo=None),
    }

    mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data)
    mock_db.commit.assert_called_once()

    # Verify no vector index task was triggered
    mock_task.delay.assert_not_called()

    # Verify return value
    assert result == mock_dataset
|
Reference in New Issue
Block a user