Azure openai init (#1929)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional, List, Dict, Tuple, Iterator
|
||||
|
||||
@@ -11,8 +11,9 @@ from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, S
|
||||
from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
|
||||
from core.model_runtime.entities.model_entities import ModelType, FetchFrom
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \
|
||||
ConfigurateMethod
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
@@ -22,6 +23,8 @@ from models.provider import ProviderType, Provider, ProviderModel, TenantPreferr
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
original_provider_configurate_methods = {}
|
||||
|
||||
|
||||
class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
@@ -34,6 +37,20 @@ class ProviderConfiguration(BaseModel):
|
||||
system_configuration: SystemConfiguration
|
||||
custom_configuration: CustomConfiguration
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
if self.provider.provider not in original_provider_configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider] = []
|
||||
for configurate_method in self.provider.configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
|
||||
|
||||
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
||||
if (any([len(quota_configuration.restrict_models) > 0
|
||||
for quota_configuration in self.system_configuration.quota_configurations])
|
||||
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
|
||||
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
|
||||
|
||||
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
|
||||
"""
|
||||
Get current credentials.
|
||||
@@ -123,7 +140,8 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
if provider_record:
|
||||
try:
|
||||
original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {}
|
||||
original_credentials = json.loads(
|
||||
provider_record.encrypted_config) if provider_record.encrypted_config else {}
|
||||
except JSONDecodeError:
|
||||
original_credentials = {}
|
||||
|
||||
@@ -265,7 +283,8 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
if provider_model_record:
|
||||
try:
|
||||
original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
|
||||
original_credentials = json.loads(
|
||||
provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
|
||||
except JSONDecodeError:
|
||||
original_credentials = {}
|
||||
|
||||
@@ -534,21 +553,70 @@ class ProviderConfiguration(BaseModel):
|
||||
]
|
||||
)
|
||||
|
||||
if self.provider.provider not in original_provider_configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider] = []
|
||||
for configurate_method in provider_instance.get_provider_schema().configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
|
||||
|
||||
should_use_custom_model = False
|
||||
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
||||
should_use_custom_model = True
|
||||
|
||||
for quota_configuration in self.system_configuration.quota_configurations:
|
||||
if self.system_configuration.current_quota_type != quota_configuration.quota_type:
|
||||
continue
|
||||
|
||||
restrict_llms = quota_configuration.restrict_llms
|
||||
if not restrict_llms:
|
||||
restrict_models = quota_configuration.restrict_models
|
||||
if len(restrict_models) == 0:
|
||||
break
|
||||
|
||||
if should_use_custom_model:
|
||||
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
||||
# only customizable model
|
||||
for restrict_model in restrict_models:
|
||||
copy_credentials = self.system_configuration.credentials.copy()
|
||||
if restrict_model.base_model_name:
|
||||
copy_credentials['base_model_name'] = restrict_model.base_model_name
|
||||
|
||||
try:
|
||||
custom_model_schema = (
|
||||
provider_instance.get_model_instance(restrict_model.model_type)
|
||||
.get_customizable_model_schema_from_credentials(
|
||||
restrict_model.model,
|
||||
copy_credentials
|
||||
)
|
||||
)
|
||||
except Exception as ex:
|
||||
logger.warning(f'get custom model schema failed, {ex}')
|
||||
continue
|
||||
|
||||
if not custom_model_schema:
|
||||
continue
|
||||
|
||||
if custom_model_schema.model_type not in model_types:
|
||||
continue
|
||||
|
||||
provider_models.append(
|
||||
ModelWithProviderEntity(
|
||||
model=custom_model_schema.model,
|
||||
label=custom_model_schema.label,
|
||||
model_type=custom_model_schema.model_type,
|
||||
features=custom_model_schema.features,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties=custom_model_schema.model_properties,
|
||||
deprecated=custom_model_schema.deprecated,
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
status=ModelStatus.ACTIVE
|
||||
)
|
||||
)
|
||||
|
||||
# if llm name not in restricted llm list, remove it
|
||||
restrict_model_names = [rm.model for rm in restrict_models]
|
||||
for m in provider_models:
|
||||
if m.model_type == ModelType.LLM and m.model not in restrict_llms:
|
||||
if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
|
||||
m.status = ModelStatus.NO_PERMISSION
|
||||
elif not quota_configuration.is_valid:
|
||||
m.status = ModelStatus.QUOTA_EXCEEDED
|
||||
|
||||
return provider_models
|
||||
|
||||
def _get_custom_provider_models(self,
|
||||
|
Reference in New Issue
Block a user