Refactor update dataset (fix #21401) (#21402)

Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
NeatGuyCoding
2025-06-25 11:44:35 +08:00
committed by GitHub
parent 819c02f1f5
commit 94f8e48647
3 changed files with 1133 additions and 153 deletions

View File

@@ -280,174 +280,328 @@ class DatasetService:
@staticmethod
def update_dataset(dataset_id, data, user):
"""
Update dataset configuration and settings.
Args:
dataset_id: The unique identifier of the dataset to update
data: Dictionary containing the update data
user: The user performing the update operation
Returns:
Dataset: The updated dataset object
Raises:
ValueError: If dataset not found or validation fails
NoPermissionError: If user lacks permission to update the dataset
"""
# Retrieve and validate dataset existence
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise ValueError("Dataset not found")
# Verify user has permission to update this dataset
DatasetService.check_dataset_permission(dataset, user)
# Handle external dataset updates
if dataset.provider == "external":
external_retrieval_model = data.get("external_retrieval_model", None)
if external_retrieval_model:
dataset.retrieval_model = external_retrieval_model
dataset.name = data.get("name", dataset.name)
dataset.description = data.get("description", "")
permission = data.get("permission")
if permission:
dataset.permission = permission
external_knowledge_id = data.get("external_knowledge_id", None)
db.session.add(dataset)
if not external_knowledge_id:
raise ValueError("External knowledge id is required.")
external_knowledge_api_id = data.get("external_knowledge_api_id", None)
if not external_knowledge_api_id:
raise ValueError("External knowledge api id is required.")
with Session(db.engine) as session:
external_knowledge_binding = (
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
)
if not external_knowledge_binding:
raise ValueError("External knowledge binding not found.")
if (
external_knowledge_binding.external_knowledge_id != external_knowledge_id
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
):
external_knowledge_binding.external_knowledge_id = external_knowledge_id
external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
db.session.add(external_knowledge_binding)
db.session.commit()
return DatasetService._update_external_dataset(dataset, data, user)
else:
data.pop("partial_member_list", None)
data.pop("external_knowledge_api_id", None)
data.pop("external_knowledge_id", None)
data.pop("external_retrieval_model", None)
filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"}
action = None
if dataset.indexing_technique != data["indexing_technique"]:
# if update indexing_technique
if data["indexing_technique"] == "economy":
action = "remove"
filtered_data["embedding_model"] = None
filtered_data["embedding_model_provider"] = None
filtered_data["collection_binding_id"] = None
elif data["indexing_technique"] == "high_quality":
action = "add"
# get embedding model setting
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=data["embedding_model"],
)
filtered_data["embedding_model"] = embedding_model.model
filtered_data["embedding_model_provider"] = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
filtered_data["collection_binding_id"] = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
else:
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
# Skip embedding model checks if not provided in the update request
if (
"embedding_model_provider" not in data
or "embedding_model" not in data
or not data.get("embedding_model_provider")
or not data.get("embedding_model")
):
# If the dataset already has embedding model settings, use those
if dataset.embedding_model_provider and dataset.embedding_model:
# Keep existing values
filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
filtered_data["embedding_model"] = dataset.embedding_model
# If collection_binding_id exists, keep it too
if dataset.collection_binding_id:
filtered_data["collection_binding_id"] = dataset.collection_binding_id
# Otherwise, don't try to update embedding model settings at all
# Remove these fields from filtered_data if they exist but are None/empty
if "embedding_model_provider" in filtered_data and not filtered_data["embedding_model_provider"]:
del filtered_data["embedding_model_provider"]
if "embedding_model" in filtered_data and not filtered_data["embedding_model"]:
del filtered_data["embedding_model"]
else:
skip_embedding_update = False
try:
# Handle existing model provider
plugin_model_provider = dataset.embedding_model_provider
plugin_model_provider_str = None
if plugin_model_provider:
plugin_model_provider_str = str(ModelProviderID(plugin_model_provider))
return DatasetService._update_internal_dataset(dataset, data, user)
# Handle new model provider from request
new_plugin_model_provider = data["embedding_model_provider"]
new_plugin_model_provider_str = None
if new_plugin_model_provider:
new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider))
@staticmethod
def _update_external_dataset(dataset, data, user):
"""
Update external dataset configuration.
# Only update embedding model if both values are provided and different from current
if (
plugin_model_provider_str != new_plugin_model_provider_str
or data["embedding_model"] != dataset.embedding_model
):
action = "update"
model_manager = ModelManager()
try:
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=data["embedding_model"],
)
except ProviderTokenNotInitError:
# If we can't get the embedding model, skip updating it
# and keep the existing settings if available
if dataset.embedding_model_provider and dataset.embedding_model:
filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
filtered_data["embedding_model"] = dataset.embedding_model
if dataset.collection_binding_id:
filtered_data["collection_binding_id"] = dataset.collection_binding_id
# Skip the rest of the embedding model update
skip_embedding_update = True
if not skip_embedding_update:
filtered_data["embedding_model"] = embedding_model.model
filtered_data["embedding_model_provider"] = embedding_model.provider
dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
)
filtered_data["collection_binding_id"] = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
Args:
dataset: The dataset object to update
data: Update data dictionary
user: User performing the update
filtered_data["updated_by"] = user.id
filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
Returns:
Dataset: Updated dataset object
"""
# Update retrieval model if provided
external_retrieval_model = data.get("external_retrieval_model", None)
if external_retrieval_model:
dataset.retrieval_model = external_retrieval_model
# update Retrieval model
filtered_data["retrieval_model"] = data["retrieval_model"]
# Update basic dataset properties
dataset.name = data.get("name", dataset.name)
dataset.description = data.get("description", dataset.description)
db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data)
# Update permission if provided
permission = data.get("permission")
if permission:
dataset.permission = permission
# Validate and update external knowledge configuration
external_knowledge_id = data.get("external_knowledge_id", None)
external_knowledge_api_id = data.get("external_knowledge_api_id", None)
if not external_knowledge_id:
raise ValueError("External knowledge id is required.")
if not external_knowledge_api_id:
raise ValueError("External knowledge api id is required.")
# Update metadata fields
dataset.updated_by = user.id if user else None
dataset.updated_at = datetime.datetime.utcnow()
db.session.add(dataset)
# Update external knowledge binding
DatasetService._update_external_knowledge_binding(dataset.id, external_knowledge_id, external_knowledge_api_id)
# Commit changes to database
db.session.commit()
db.session.commit()
if action:
deal_dataset_vector_index_task.delay(dataset_id, action)
return dataset
@staticmethod
def _update_external_knowledge_binding(dataset_id, external_knowledge_id, external_knowledge_api_id):
"""
Update external knowledge binding configuration.
Args:
dataset_id: Dataset identifier
external_knowledge_id: External knowledge identifier
external_knowledge_api_id: External knowledge API identifier
"""
with Session(db.engine) as session:
external_knowledge_binding = (
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
)
if not external_knowledge_binding:
raise ValueError("External knowledge binding not found.")
# Update binding if values have changed
if (
external_knowledge_binding.external_knowledge_id != external_knowledge_id
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
):
external_knowledge_binding.external_knowledge_id = external_knowledge_id
external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
db.session.add(external_knowledge_binding)
@staticmethod
def _update_internal_dataset(dataset, data, user):
"""
Update internal dataset configuration.
Args:
dataset: The dataset object to update
data: Update data dictionary
user: User performing the update
Returns:
Dataset: Updated dataset object
"""
# Remove external-specific fields from update data
data.pop("partial_member_list", None)
data.pop("external_knowledge_api_id", None)
data.pop("external_knowledge_id", None)
data.pop("external_retrieval_model", None)
# Filter out None values except for description field
filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"}
# Handle indexing technique changes and embedding model updates
action = DatasetService._handle_indexing_technique_change(dataset, data, filtered_data)
# Add metadata fields
filtered_data["updated_by"] = user.id
filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
# update Retrieval model
filtered_data["retrieval_model"] = data["retrieval_model"]
# Update dataset in database
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
db.session.commit()
# Trigger vector index task if indexing technique changed
if action:
deal_dataset_vector_index_task.delay(dataset.id, action)
return dataset
@staticmethod
def _handle_indexing_technique_change(dataset, data, filtered_data):
"""
Handle changes in indexing technique and configure embedding models accordingly.
Args:
dataset: Current dataset object
data: Update data dictionary
filtered_data: Filtered update data
Returns:
str: Action to perform ('add', 'remove', 'update', or None)
"""
if dataset.indexing_technique != data["indexing_technique"]:
if data["indexing_technique"] == "economy":
# Remove embedding model configuration for economy mode
filtered_data["embedding_model"] = None
filtered_data["embedding_model_provider"] = None
filtered_data["collection_binding_id"] = None
return "remove"
elif data["indexing_technique"] == "high_quality":
# Configure embedding model for high quality mode
DatasetService._configure_embedding_model_for_high_quality(data, filtered_data)
return "add"
else:
# Handle embedding model updates when indexing technique remains the same
return DatasetService._handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data)
return None
@staticmethod
def _configure_embedding_model_for_high_quality(data, filtered_data):
"""
Configure embedding model settings for high quality indexing.
Args:
data: Update data dictionary
filtered_data: Filtered update data to modify
"""
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=data["embedding_model"],
)
filtered_data["embedding_model"] = embedding_model.model
filtered_data["embedding_model_provider"] = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
filtered_data["collection_binding_id"] = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
@staticmethod
def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data):
"""
Handle embedding model updates when indexing technique remains the same.
Args:
dataset: Current dataset object
data: Update data dictionary
filtered_data: Filtered update data to modify
Returns:
str: Action to perform ('update' or None)
"""
# Skip embedding model checks if not provided in the update request
if (
"embedding_model_provider" not in data
or "embedding_model" not in data
or not data.get("embedding_model_provider")
or not data.get("embedding_model")
):
DatasetService._preserve_existing_embedding_settings(dataset, filtered_data)
return None
else:
return DatasetService._update_embedding_model_settings(dataset, data, filtered_data)
@staticmethod
def _preserve_existing_embedding_settings(dataset, filtered_data):
"""
Preserve existing embedding model settings when not provided in update.
Args:
dataset: Current dataset object
filtered_data: Filtered update data to modify
"""
# If the dataset already has embedding model settings, use those
if dataset.embedding_model_provider and dataset.embedding_model:
filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
filtered_data["embedding_model"] = dataset.embedding_model
# If collection_binding_id exists, keep it too
if dataset.collection_binding_id:
filtered_data["collection_binding_id"] = dataset.collection_binding_id
# Otherwise, don't try to update embedding model settings at all
# Remove these fields from filtered_data if they exist but are None/empty
if "embedding_model_provider" in filtered_data and not filtered_data["embedding_model_provider"]:
del filtered_data["embedding_model_provider"]
if "embedding_model" in filtered_data and not filtered_data["embedding_model"]:
del filtered_data["embedding_model"]
@staticmethod
def _update_embedding_model_settings(dataset, data, filtered_data):
"""
Update embedding model settings with new values.
Args:
dataset: Current dataset object
data: Update data dictionary
filtered_data: Filtered update data to modify
Returns:
str: Action to perform ('update' or None)
"""
try:
# Compare current and new model provider settings
current_provider_str = (
str(ModelProviderID(dataset.embedding_model_provider)) if dataset.embedding_model_provider else None
)
new_provider_str = (
str(ModelProviderID(data["embedding_model_provider"])) if data["embedding_model_provider"] else None
)
# Only update if values are different
if current_provider_str != new_provider_str or data["embedding_model"] != dataset.embedding_model:
DatasetService._apply_new_embedding_settings(dataset, data, filtered_data)
return "update"
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
return None
@staticmethod
def _apply_new_embedding_settings(dataset, data, filtered_data):
"""
Apply new embedding model settings to the dataset.
Args:
dataset: Current dataset object
data: Update data dictionary
filtered_data: Filtered update data to modify
"""
model_manager = ModelManager()
try:
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=data["embedding_model"],
)
except ProviderTokenNotInitError:
# If we can't get the embedding model, preserve existing settings
if dataset.embedding_model_provider and dataset.embedding_model:
filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
filtered_data["embedding_model"] = dataset.embedding_model
if dataset.collection_binding_id:
filtered_data["collection_binding_id"] = dataset.collection_binding_id
# Skip the rest of the embedding model update
return
# Apply new embedding model settings
filtered_data["embedding_model"] = embedding_model.model
filtered_data["embedding_model_provider"] = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
filtered_data["collection_binding_id"] = dataset_collection_binding.id
@staticmethod
def delete_dataset(dataset_id, user):
dataset = DatasetService.get_dataset(dataset_id)