Azure openai init (#1929)

Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
Charlie.Wei
2024-01-09 19:17:47 +08:00
committed by GitHub
parent b8592ad412
commit 5b24d7129e
8 changed files with 151 additions and 34 deletions

View File

@@ -4,13 +4,14 @@ from typing import Optional
from flask import Flask
from pydantic import BaseModel
from core.entities.provider_entities import QuotaUnit
from core.entities.provider_entities import QuotaUnit, RestrictModel
from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType
class HostingQuota(BaseModel):
quota_type: ProviderQuotaType
restrict_llms: list[str] = []
restrict_models: list[RestrictModel] = []
class TrialHostingQuota(HostingQuota):
@@ -47,10 +48,9 @@ class HostingConfiguration:
provider_map: dict[str, HostingProvider] = {}
moderation_config: HostedModerationConfig = None
def init_app(self, app: Flask):
if app.config.get('EDITION') != 'CLOUD':
return
def init_app(self, app: Flask) -> None:
self.provider_map["azure_openai"] = self.init_azure_openai()
self.provider_map["openai"] = self.init_openai()
self.provider_map["anthropic"] = self.init_anthropic()
self.provider_map["minimax"] = self.init_minimax()
@@ -59,6 +59,47 @@ class HostingConfiguration:
self.moderation_config = self.init_moderation_config()
def init_azure_openai(self) -> HostingProvider:
quota_unit = QuotaUnit.TIMES
if os.environ.get("HOSTED_AZURE_OPENAI_ENABLED") and os.environ.get("HOSTED_AZURE_OPENAI_ENABLED").lower() == 'true':
credentials = {
"openai_api_key": os.environ.get("HOSTED_AZURE_OPENAI_API_KEY"),
"openai_api_base": os.environ.get("HOSTED_AZURE_OPENAI_API_BASE"),
"base_model_name": "gpt-35-turbo"
}
quotas = []
hosted_quota_limit = int(os.environ.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))
if hosted_quota_limit != -1 or hosted_quota_limit > 0:
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit,
restrict_models=[
RestrictModel(model="gpt-4", base_model_name="gpt-4", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-32k", base_model_name="gpt-4-32k", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-35-turbo", base_model_name="gpt-35-turbo", model_type=ModelType.LLM),
RestrictModel(model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM),
RestrictModel(model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM),
RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM),
RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING),
]
)
quotas.append(trial_quota)
return HostingProvider(
enabled=True,
credentials=credentials,
quota_unit=quota_unit,
quotas=quotas
)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_openai(self) -> HostingProvider:
quota_unit = QuotaUnit.TIMES
if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true':
@@ -77,12 +118,12 @@ class HostingConfiguration:
if hosted_quota_limit != -1 or hosted_quota_limit > 0:
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit,
restrict_llms=[
"gpt-3.5-turbo",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-instruct",
"gpt-3.5-turbo-16k",
"text-davinci-003"
restrict_models=[
RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
]
)
quotas.append(trial_quota)