fix(plugin/migrations) refactor data migration to use specific provider ID classes. (#21187)

This commit is contained in:
Yeuoly
2025-06-19 13:02:39 +08:00
committed by GitHub
parent 2c04a16eaa
commit 2020a31785

View File

@@ -3,7 +3,7 @@ import logging
import click
from core.entities import DEFAULT_PLUGIN_ID
from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID
from models.engine import db
logger = logging.getLogger(__name__)
@@ -12,17 +12,17 @@ logger = logging.getLogger(__name__)
class PluginDataMigration:
@classmethod
def migrate(cls) -> None:
cls.migrate_db_records("providers", "provider_name") # large table
cls.migrate_db_records("provider_models", "provider_name")
cls.migrate_db_records("provider_orders", "provider_name")
cls.migrate_db_records("tenant_default_models", "provider_name")
cls.migrate_db_records("tenant_preferred_model_providers", "provider_name")
cls.migrate_db_records("provider_model_settings", "provider_name")
cls.migrate_db_records("load_balancing_model_configs", "provider_name")
cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID)
cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID)
cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID)
cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID)
cls.migrate_datasets()
cls.migrate_db_records("embeddings", "provider_name") # large table
cls.migrate_db_records("dataset_collection_bindings", "provider_name")
cls.migrate_db_records("tool_builtin_providers", "provider")
cls.migrate_db_records("embeddings", "provider_name", ModelProviderID) # large table
cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID)
cls.migrate_db_records("tool_builtin_providers", "provider_name", ToolProviderID)
@classmethod
def migrate_datasets(cls) -> None:
@@ -66,9 +66,10 @@ limit 1000"""
fg="white",
)
)
retrieval_model["reranking_model"]["reranking_provider_name"] = (
f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}"
)
# update google to langgenius/gemini/google etc.
retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID(
retrieval_model["reranking_model"]["reranking_provider_name"]
).to_string()
retrieval_model_changed = True
click.echo(
@@ -86,9 +87,11 @@ limit 1000"""
update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
params["retrieval_model"] = json.dumps(retrieval_model)
params["provider_name"] = ModelProviderID(provider_name).to_string()
sql = f"""update {table_name}
set {provider_column_name} =
concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
:provider_name
{update_retrieval_model_sql}
where id = :record_id"""
conn.execute(db.text(sql), params)
@@ -122,7 +125,9 @@ limit 1000"""
)
@classmethod
def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None:
def migrate_db_records(
cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]
) -> None:
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
processed_count = 0
@@ -166,7 +171,8 @@ limit 1000"""
)
try:
updated_value = f"{DEFAULT_PLUGIN_ID}/{provider_name}/{provider_name}"
# update jina to langgenius/jina_tool/jina etc.
updated_value = provider_cls(provider_name).to_string()
batch_updates.append((updated_value, record_id))
except Exception as e:
failed_ids.append(record_id)