Azure openai init (#1929)

Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
Charlie.Wei
2024-01-09 19:17:47 +08:00
committed by GitHub
parent b8592ad412
commit 5b24d7129e
8 changed files with 151 additions and 34 deletions

View File

@@ -32,7 +32,7 @@ class ModelType(Enum):
return cls.TEXT_EMBEDDING
elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value:
return cls.RERANK
elif origin_model_type == cls.SPEECH2TEXT.value:
elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value:
return cls.SPEECH2TEXT
elif origin_model_type == cls.MODERATION.value:
return cls.MODERATION

View File

@@ -2,7 +2,7 @@ from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature, ModelType, FetchFrom, ParameterRule, \
DefaultParameterName, PriceConfig
DefaultParameterName, PriceConfig, ModelPropertyKey
from core.model_runtime.entities.model_entities import AIModelEntity, I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
@@ -502,8 +502,8 @@ EMBEDDING_BASE_MODELS = [
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
'context_size': 8097,
'max_chunks': 32,
ModelPropertyKey.CONTEXT_SIZE: 8097,
ModelPropertyKey.MAX_CHUNKS: 32,
},
pricing=PriceConfig(
input=0.0001,

View File

@@ -30,7 +30,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
# chat model
@@ -59,7 +59,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
model_mode = self._get_ai_model_entity(credentials['base_model_name'], model).entity.model_properties.get(
model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
ModelPropertyKey.MODE)
if model_mode == LLMMode.CHAT.value:
@@ -79,7 +79,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
if 'base_model_name' not in credentials:
raise CredentialsValidateFailedError('Base Model Name is required')
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
if not ai_model_entity:
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
@@ -109,8 +109,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
return ai_model_entity.entity
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
return ai_model_entity.entity if ai_model_entity else None
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,