fix: old custom model not display credential name (#25112)

This commit is contained in:
非法操作
2025-09-04 10:46:10 +08:00
committed by GitHub
parent c22b325c31
commit 0a0ae16bd6

View File

@@ -150,6 +150,9 @@ class ProviderManager:
tenant_id
)
# Get All provider model credentials
provider_name_to_provider_model_credentials_dict = self._get_all_provider_model_credentials(tenant_id)
provider_configurations = ProviderConfigurations(tenant_id=tenant_id)
# Construct ProviderConfiguration objects for each provider
@@ -171,10 +174,18 @@ class ProviderManager:
provider_model_records.extend(
provider_name_to_provider_model_records_dict.get(provider_id_entity.provider_name, [])
)
provider_model_credentials = provider_name_to_provider_model_credentials_dict.get(
provider_entity.provider, []
)
provider_id_entity = ModelProviderID(provider_name)
if provider_id_entity.is_langgenius():
provider_model_credentials.extend(
provider_name_to_provider_model_credentials_dict.get(provider_id_entity.provider_name, [])
)
# Convert to custom configuration
custom_configuration = self._to_custom_configuration(
tenant_id, provider_entity, provider_records, provider_model_records
tenant_id, provider_entity, provider_records, provider_model_records, provider_model_credentials
)
# Convert to system configuration
@@ -453,6 +464,24 @@ class ProviderManager:
)
return provider_name_to_provider_model_settings_dict
@staticmethod
def _get_all_provider_model_credentials(tenant_id: str) -> dict[str, list[ProviderModelCredential]]:
"""
Get All provider model credentials of the workspace.
:param tenant_id: workspace id
:return:
"""
provider_name_to_provider_model_credentials_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id)
provider_model_credentials = session.scalars(stmt)
for provider_model_credential in provider_model_credentials:
provider_name_to_provider_model_credentials_dict[provider_model_credential.provider_name].append(
provider_model_credential
)
return provider_name_to_provider_model_credentials_dict
@staticmethod
def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]:
"""
@@ -539,23 +568,6 @@ class ProviderManager:
for credential in available_credentials
]
@staticmethod
def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]:
"""
Get all the credentials records from ProviderModelCredential by provider_name
:param tenant_id: workspace id
:param provider_name: provider name
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name
)
all_credentials = session.scalars(stmt).all()
return all_credentials
@staticmethod
def _init_trial_provider_records(
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
@@ -632,6 +644,7 @@ class ProviderManager:
provider_entity: ProviderEntity,
provider_records: list[Provider],
provider_model_records: list[ProviderModel],
provider_model_credentials: list[ProviderModelCredential],
) -> CustomConfiguration:
"""
Convert to custom configuration.
@@ -647,15 +660,12 @@ class ProviderManager:
tenant_id, provider_entity, provider_records
)
# Get all model credentials once
all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider)
# Get custom models which have not been added to the model list yet
unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials)
unadded_models = self._get_can_added_models(provider_model_records, provider_model_credentials)
# Get custom model configurations
custom_model_configurations = self._get_custom_model_configurations(
tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials
tenant_id, provider_entity, provider_model_records, unadded_models, provider_model_credentials
)
can_added_models = [