diff --git a/api/services/plugin/data_migration.py b/api/services/plugin/data_migration.py index 1c5abfecb..02de5a79d 100644 --- a/api/services/plugin/data_migration.py +++ b/api/services/plugin/data_migration.py @@ -3,7 +3,7 @@ import logging import click -from core.entities import DEFAULT_PLUGIN_ID +from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID from models.engine import db logger = logging.getLogger(__name__) @@ -12,17 +12,17 @@ logger = logging.getLogger(__name__) class PluginDataMigration: @classmethod def migrate(cls) -> None: - cls.migrate_db_records("providers", "provider_name") # large table - cls.migrate_db_records("provider_models", "provider_name") - cls.migrate_db_records("provider_orders", "provider_name") - cls.migrate_db_records("tenant_default_models", "provider_name") - cls.migrate_db_records("tenant_preferred_model_providers", "provider_name") - cls.migrate_db_records("provider_model_settings", "provider_name") - cls.migrate_db_records("load_balancing_model_configs", "provider_name") + cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table + cls.migrate_db_records("provider_models", "provider_name", ModelProviderID) + cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID) + cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID) + cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID) + cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID) + cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID) cls.migrate_datasets() - cls.migrate_db_records("embeddings", "provider_name") # large table - cls.migrate_db_records("dataset_collection_bindings", "provider_name") - cls.migrate_db_records("tool_builtin_providers", "provider") + cls.migrate_db_records("embeddings", "provider_name", ModelProviderID) # large table + cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID) + cls.migrate_db_records("tool_builtin_providers", "provider_name", ToolProviderID) @classmethod def migrate_datasets(cls) -> None: @@ -66,9 +66,10 @@ limit 1000""" fg="white", ) ) - retrieval_model["reranking_model"]["reranking_provider_name"] = ( - f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}" - ) + # update google to langgenius/gemini/google etc. + retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID( + retrieval_model["reranking_model"]["reranking_provider_name"] + ).to_string() retrieval_model_changed = True click.echo( @@ -86,9 +87,11 @@ limit 1000""" update_retrieval_model_sql = ", retrieval_model = :retrieval_model" params["retrieval_model"] = json.dumps(retrieval_model) + params["provider_name"] = ModelProviderID(provider_name).to_string() + sql = f"""update {table_name} set {provider_column_name} = - concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name}) + :provider_name {update_retrieval_model_sql} where id = :record_id""" conn.execute(db.text(sql), params) @@ -122,7 +125,9 @@ limit 1000""" ) @classmethod - def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None: + def migrate_db_records( + cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID] + ) -> None: click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white")) processed_count = 0 @@ -166,7 +171,8 @@ limit 1000""" ) try: - updated_value = f"{DEFAULT_PLUGIN_ID}/{provider_name}/{provider_name}" + # update jina to langgenius/jina_tool/jina etc. + updated_value = provider_cls(provider_name).to_string() batch_updates.append((updated_value, record_id)) except Exception as e: failed_ids.append(record_id)