fix(plugin/migrations): refactor data migration to use specific provider ID classes. (#21187)
This commit is contained in:
@@ -3,7 +3,7 @@ import logging
|
||||
|
||||
import click
|
||||
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID
|
||||
from models.engine import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -12,17 +12,17 @@ logger = logging.getLogger(__name__)
|
||||
class PluginDataMigration:
|
||||
@classmethod
|
||||
def migrate(cls) -> None:
|
||||
cls.migrate_db_records("providers", "provider_name") # large table
|
||||
cls.migrate_db_records("provider_models", "provider_name")
|
||||
cls.migrate_db_records("provider_orders", "provider_name")
|
||||
cls.migrate_db_records("tenant_default_models", "provider_name")
|
||||
cls.migrate_db_records("tenant_preferred_model_providers", "provider_name")
|
||||
cls.migrate_db_records("provider_model_settings", "provider_name")
|
||||
cls.migrate_db_records("load_balancing_model_configs", "provider_name")
|
||||
cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
|
||||
cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
|
||||
cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
|
||||
cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID)
|
||||
cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID)
|
||||
cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID)
|
||||
cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID)
|
||||
cls.migrate_datasets()
|
||||
cls.migrate_db_records("embeddings", "provider_name") # large table
|
||||
cls.migrate_db_records("dataset_collection_bindings", "provider_name")
|
||||
cls.migrate_db_records("tool_builtin_providers", "provider")
|
||||
cls.migrate_db_records("embeddings", "provider_name", ModelProviderID) # large table
|
||||
cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID)
|
||||
cls.migrate_db_records("tool_builtin_providers", "provider_name", ToolProviderID)
|
||||
|
||||
@classmethod
|
||||
def migrate_datasets(cls) -> None:
|
||||
@@ -66,9 +66,10 @@ limit 1000"""
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
retrieval_model["reranking_model"]["reranking_provider_name"] = (
|
||||
f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}"
|
||||
)
|
||||
# update google to langgenius/gemini/google etc.
|
||||
retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID(
|
||||
retrieval_model["reranking_model"]["reranking_provider_name"]
|
||||
).to_string()
|
||||
retrieval_model_changed = True
|
||||
|
||||
click.echo(
|
||||
@@ -86,9 +87,11 @@ limit 1000"""
|
||||
update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
|
||||
params["retrieval_model"] = json.dumps(retrieval_model)
|
||||
|
||||
params["provider_name"] = ModelProviderID(provider_name).to_string()
|
||||
|
||||
sql = f"""update {table_name}
|
||||
set {provider_column_name} =
|
||||
concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
|
||||
:provider_name
|
||||
{update_retrieval_model_sql}
|
||||
where id = :record_id"""
|
||||
conn.execute(db.text(sql), params)
|
||||
@@ -122,7 +125,9 @@ limit 1000"""
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None:
|
||||
def migrate_db_records(
|
||||
cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]
|
||||
) -> None:
|
||||
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
|
||||
|
||||
processed_count = 0
|
||||
@@ -166,7 +171,8 @@ limit 1000"""
|
||||
)
|
||||
|
||||
try:
|
||||
updated_value = f"{DEFAULT_PLUGIN_ID}/{provider_name}/{provider_name}"
|
||||
# update jina to langgenius/jina_tool/jina etc.
|
||||
updated_value = provider_cls(provider_name).to_string()
|
||||
batch_updates.append((updated_value, record_id))
|
||||
except Exception as e:
|
||||
failed_ids.append(record_id)
|
||||
|
Reference in New Issue
Block a user