fix: old custom model not display credential name (#25112)
This commit is contained in:
@@ -150,6 +150,9 @@ class ProviderManager:
|
||||
tenant_id
|
||||
)
|
||||
|
||||
# Get All provider model credentials
|
||||
provider_name_to_provider_model_credentials_dict = self._get_all_provider_model_credentials(tenant_id)
|
||||
|
||||
provider_configurations = ProviderConfigurations(tenant_id=tenant_id)
|
||||
|
||||
# Construct ProviderConfiguration objects for each provider
|
||||
@@ -171,10 +174,18 @@ class ProviderManager:
|
||||
provider_model_records.extend(
|
||||
provider_name_to_provider_model_records_dict.get(provider_id_entity.provider_name, [])
|
||||
)
|
||||
provider_model_credentials = provider_name_to_provider_model_credentials_dict.get(
|
||||
provider_entity.provider, []
|
||||
)
|
||||
provider_id_entity = ModelProviderID(provider_name)
|
||||
if provider_id_entity.is_langgenius():
|
||||
provider_model_credentials.extend(
|
||||
provider_name_to_provider_model_credentials_dict.get(provider_id_entity.provider_name, [])
|
||||
)
|
||||
|
||||
# Convert to custom configuration
|
||||
custom_configuration = self._to_custom_configuration(
|
||||
tenant_id, provider_entity, provider_records, provider_model_records
|
||||
tenant_id, provider_entity, provider_records, provider_model_records, provider_model_credentials
|
||||
)
|
||||
|
||||
# Convert to system configuration
|
||||
@@ -453,6 +464,24 @@ class ProviderManager:
|
||||
)
|
||||
return provider_name_to_provider_model_settings_dict
|
||||
|
||||
@staticmethod
|
||||
def _get_all_provider_model_credentials(tenant_id: str) -> dict[str, list[ProviderModelCredential]]:
|
||||
"""
|
||||
Get All provider model credentials of the workspace.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
provider_name_to_provider_model_credentials_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id)
|
||||
provider_model_credentials = session.scalars(stmt)
|
||||
for provider_model_credential in provider_model_credentials:
|
||||
provider_name_to_provider_model_credentials_dict[provider_model_credential.provider_name].append(
|
||||
provider_model_credential
|
||||
)
|
||||
return provider_name_to_provider_model_credentials_dict
|
||||
|
||||
@staticmethod
|
||||
def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]:
|
||||
"""
|
||||
@@ -539,23 +568,6 @@ class ProviderManager:
|
||||
for credential in available_credentials
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]:
|
||||
"""
|
||||
Get all the credentials records from ProviderModelCredential by provider_name
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider_name: provider name
|
||||
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name
|
||||
)
|
||||
|
||||
all_credentials = session.scalars(stmt).all()
|
||||
return all_credentials
|
||||
|
||||
@staticmethod
|
||||
def _init_trial_provider_records(
|
||||
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
|
||||
@@ -632,6 +644,7 @@ class ProviderManager:
|
||||
provider_entity: ProviderEntity,
|
||||
provider_records: list[Provider],
|
||||
provider_model_records: list[ProviderModel],
|
||||
provider_model_credentials: list[ProviderModelCredential],
|
||||
) -> CustomConfiguration:
|
||||
"""
|
||||
Convert to custom configuration.
|
||||
@@ -647,15 +660,12 @@ class ProviderManager:
|
||||
tenant_id, provider_entity, provider_records
|
||||
)
|
||||
|
||||
# Get all model credentials once
|
||||
all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider)
|
||||
|
||||
# Get custom models which have not been added to the model list yet
|
||||
unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials)
|
||||
unadded_models = self._get_can_added_models(provider_model_records, provider_model_credentials)
|
||||
|
||||
# Get custom model configurations
|
||||
custom_model_configurations = self._get_custom_model_configurations(
|
||||
tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials
|
||||
tenant_id, provider_entity, provider_model_records, unadded_models, provider_model_credentials
|
||||
)
|
||||
|
||||
can_added_models = [
|
||||
|
Reference in New Issue
Block a user