Azure openai init (#1929)

Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
Charlie.Wei
2024-01-09 19:17:47 +08:00
committed by GitHub
parent b8592ad412
commit 5b24d7129e
8 changed files with 151 additions and 34 deletions

View File

@@ -1,7 +1,7 @@
import datetime
import json
import logging
import time
from json import JSONDecodeError
from typing import Optional, List, Dict, Tuple, Iterator
@@ -11,8 +11,9 @@ from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, S
from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
from core.model_runtime.entities.model_entities import ModelType, FetchFrom
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \
ConfigurateMethod
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
@@ -22,6 +23,8 @@ from models.provider import ProviderType, Provider, ProviderModel, TenantPreferr
logger = logging.getLogger(__name__)
original_provider_configurate_methods = {}
class ProviderConfiguration(BaseModel):
"""
@@ -34,6 +37,20 @@ class ProviderConfiguration(BaseModel):
system_configuration: SystemConfiguration
custom_configuration: CustomConfiguration
def __init__(self, **data):
super().__init__(**data)
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in self.provider.configurate_methods:
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
if (any([len(quota_configuration.restrict_models) > 0
for quota_configuration in self.system_configuration.quota_configurations])
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
"""
Get current credentials.
@@ -123,7 +140,8 @@ class ProviderConfiguration(BaseModel):
if provider_record:
try:
original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {}
original_credentials = json.loads(
provider_record.encrypted_config) if provider_record.encrypted_config else {}
except JSONDecodeError:
original_credentials = {}
@@ -265,7 +283,8 @@ class ProviderConfiguration(BaseModel):
if provider_model_record:
try:
original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
original_credentials = json.loads(
provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
except JSONDecodeError:
original_credentials = {}
@@ -534,21 +553,70 @@ class ProviderConfiguration(BaseModel):
]
)
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in provider_instance.get_provider_schema().configurate_methods:
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
should_use_custom_model = False
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
should_use_custom_model = True
for quota_configuration in self.system_configuration.quota_configurations:
if self.system_configuration.current_quota_type != quota_configuration.quota_type:
continue
restrict_llms = quota_configuration.restrict_llms
if not restrict_llms:
restrict_models = quota_configuration.restrict_models
if len(restrict_models) == 0:
break
if should_use_custom_model:
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
# only customizable model
for restrict_model in restrict_models:
copy_credentials = self.system_configuration.credentials.copy()
if restrict_model.base_model_name:
copy_credentials['base_model_name'] = restrict_model.base_model_name
try:
custom_model_schema = (
provider_instance.get_model_instance(restrict_model.model_type)
.get_customizable_model_schema_from_credentials(
restrict_model.model,
copy_credentials
)
)
except Exception as ex:
logger.warning(f'get custom model schema failed, {ex}')
continue
if not custom_model_schema:
continue
if custom_model_schema.model_type not in model_types:
continue
provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)
)
# if llm name not in restricted llm list, remove it
restrict_model_names = [rm.model for rm in restrict_models]
for m in provider_models:
if m.model_type == ModelType.LLM and m.model not in restrict_llms:
if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
m.status = ModelStatus.NO_PERMISSION
elif not quota_configuration.is_valid:
m.status = ModelStatus.QUOTA_EXCEEDED
return provider_models
def _get_custom_provider_models(self,