fix(plugin/migrations) refactor data migration to use specific provider ID classes. (#21187)

This commit is contained in:
Yeuoly
2025-06-19 13:02:39 +08:00
committed by GitHub
parent 2c04a16eaa
commit 2020a31785

View File

@@ -3,7 +3,7 @@ import logging
import click import click
from core.entities import DEFAULT_PLUGIN_ID from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID
from models.engine import db from models.engine import db
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -12,17 +12,17 @@ logger = logging.getLogger(__name__)
class PluginDataMigration: class PluginDataMigration:
@classmethod @classmethod
def migrate(cls) -> None: def migrate(cls) -> None:
cls.migrate_db_records("providers", "provider_name") # large table cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
cls.migrate_db_records("provider_models", "provider_name") cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
cls.migrate_db_records("provider_orders", "provider_name") cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
cls.migrate_db_records("tenant_default_models", "provider_name") cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID)
cls.migrate_db_records("tenant_preferred_model_providers", "provider_name") cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID)
cls.migrate_db_records("provider_model_settings", "provider_name") cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID)
cls.migrate_db_records("load_balancing_model_configs", "provider_name") cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID)
cls.migrate_datasets() cls.migrate_datasets()
cls.migrate_db_records("embeddings", "provider_name") # large table cls.migrate_db_records("embeddings", "provider_name", ModelProviderID) # large table
cls.migrate_db_records("dataset_collection_bindings", "provider_name") cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID)
cls.migrate_db_records("tool_builtin_providers", "provider") cls.migrate_db_records("tool_builtin_providers", "provider_name", ToolProviderID)
@classmethod @classmethod
def migrate_datasets(cls) -> None: def migrate_datasets(cls) -> None:
@@ -66,9 +66,10 @@ limit 1000"""
fg="white", fg="white",
) )
) )
retrieval_model["reranking_model"]["reranking_provider_name"] = ( # update google to langgenius/gemini/google etc.
f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}" retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID(
) retrieval_model["reranking_model"]["reranking_provider_name"]
).to_string()
retrieval_model_changed = True retrieval_model_changed = True
click.echo( click.echo(
@@ -86,9 +87,11 @@ limit 1000"""
update_retrieval_model_sql = ", retrieval_model = :retrieval_model" update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
params["retrieval_model"] = json.dumps(retrieval_model) params["retrieval_model"] = json.dumps(retrieval_model)
params["provider_name"] = ModelProviderID(provider_name).to_string()
sql = f"""update {table_name} sql = f"""update {table_name}
set {provider_column_name} = set {provider_column_name} =
concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name}) :provider_name
{update_retrieval_model_sql} {update_retrieval_model_sql}
where id = :record_id""" where id = :record_id"""
conn.execute(db.text(sql), params) conn.execute(db.text(sql), params)
@@ -122,7 +125,9 @@ limit 1000"""
) )
@classmethod @classmethod
def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None: def migrate_db_records(
cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]
) -> None:
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white")) click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
processed_count = 0 processed_count = 0
@@ -166,7 +171,8 @@ limit 1000"""
) )
try: try:
updated_value = f"{DEFAULT_PLUGIN_ID}/{provider_name}/{provider_name}" # update jina to langgenius/jina_tool/jina etc.
updated_value = provider_cls(provider_name).to_string()
batch_updates.append((updated_value, record_id)) batch_updates.append((updated_value, record_id))
except Exception as e: except Exception as e:
failed_ids.append(record_id) failed_ids.append(record_id)