fix(plugin/migrations) refactor data migration to use specific provider ID classes. (#21187)
This commit is contained in:
@@ -3,7 +3,7 @@ import logging
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from core.entities import DEFAULT_PLUGIN_ID
|
from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID
|
||||||
from models.engine import db
|
from models.engine import db
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -12,17 +12,17 @@ logger = logging.getLogger(__name__)
|
|||||||
class PluginDataMigration:
|
class PluginDataMigration:
|
||||||
@classmethod
|
@classmethod
|
||||||
def migrate(cls) -> None:
|
def migrate(cls) -> None:
|
||||||
cls.migrate_db_records("providers", "provider_name") # large table
|
cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
|
||||||
cls.migrate_db_records("provider_models", "provider_name")
|
cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
|
||||||
cls.migrate_db_records("provider_orders", "provider_name")
|
cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
|
||||||
cls.migrate_db_records("tenant_default_models", "provider_name")
|
cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID)
|
||||||
cls.migrate_db_records("tenant_preferred_model_providers", "provider_name")
|
cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID)
|
||||||
cls.migrate_db_records("provider_model_settings", "provider_name")
|
cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID)
|
||||||
cls.migrate_db_records("load_balancing_model_configs", "provider_name")
|
cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID)
|
||||||
cls.migrate_datasets()
|
cls.migrate_datasets()
|
||||||
cls.migrate_db_records("embeddings", "provider_name") # large table
|
cls.migrate_db_records("embeddings", "provider_name", ModelProviderID) # large table
|
||||||
cls.migrate_db_records("dataset_collection_bindings", "provider_name")
|
cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID)
|
||||||
cls.migrate_db_records("tool_builtin_providers", "provider")
|
cls.migrate_db_records("tool_builtin_providers", "provider_name", ToolProviderID)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def migrate_datasets(cls) -> None:
|
def migrate_datasets(cls) -> None:
|
||||||
@@ -66,9 +66,10 @@ limit 1000"""
|
|||||||
fg="white",
|
fg="white",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
retrieval_model["reranking_model"]["reranking_provider_name"] = (
|
# update google to langgenius/gemini/google etc.
|
||||||
f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}"
|
retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID(
|
||||||
)
|
retrieval_model["reranking_model"]["reranking_provider_name"]
|
||||||
|
).to_string()
|
||||||
retrieval_model_changed = True
|
retrieval_model_changed = True
|
||||||
|
|
||||||
click.echo(
|
click.echo(
|
||||||
@@ -86,9 +87,11 @@ limit 1000"""
|
|||||||
update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
|
update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
|
||||||
params["retrieval_model"] = json.dumps(retrieval_model)
|
params["retrieval_model"] = json.dumps(retrieval_model)
|
||||||
|
|
||||||
|
params["provider_name"] = ModelProviderID(provider_name).to_string()
|
||||||
|
|
||||||
sql = f"""update {table_name}
|
sql = f"""update {table_name}
|
||||||
set {provider_column_name} =
|
set {provider_column_name} =
|
||||||
concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
|
:provider_name
|
||||||
{update_retrieval_model_sql}
|
{update_retrieval_model_sql}
|
||||||
where id = :record_id"""
|
where id = :record_id"""
|
||||||
conn.execute(db.text(sql), params)
|
conn.execute(db.text(sql), params)
|
||||||
@@ -122,7 +125,9 @@ limit 1000"""
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None:
|
def migrate_db_records(
|
||||||
|
cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]
|
||||||
|
) -> None:
|
||||||
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
|
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
|
||||||
|
|
||||||
processed_count = 0
|
processed_count = 0
|
||||||
@@ -166,7 +171,8 @@ limit 1000"""
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
updated_value = f"{DEFAULT_PLUGIN_ID}/{provider_name}/{provider_name}"
|
# update jina to langgenius/jina_tool/jina etc.
|
||||||
|
updated_value = provider_cls(provider_name).to_string()
|
||||||
batch_updates.append((updated_value, record_id))
|
batch_updates.append((updated_value, record_id))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
failed_ids.append(record_id)
|
failed_ids.append(record_id)
|
||||||
|
Reference in New Issue
Block a user