Azure openai init (#1929)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
@@ -32,7 +32,7 @@ class ModelType(Enum):
|
||||
return cls.TEXT_EMBEDDING
|
||||
elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value:
|
||||
return cls.RERANK
|
||||
elif origin_model_type == cls.SPEECH2TEXT.value:
|
||||
elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value:
|
||||
return cls.SPEECH2TEXT
|
||||
elif origin_model_type == cls.MODERATION.value:
|
||||
return cls.MODERATION
|
||||
|
@@ -2,7 +2,7 @@ from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType, FetchFrom, ParameterRule, \
|
||||
DefaultParameterName, PriceConfig
|
||||
DefaultParameterName, PriceConfig, ModelPropertyKey
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, I18nObject
|
||||
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
||||
|
||||
@@ -502,8 +502,8 @@ EMBEDDING_BASE_MODELS = [
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model_properties={
|
||||
'context_size': 8097,
|
||||
'max_chunks': 32,
|
||||
ModelPropertyKey.CONTEXT_SIZE: 8097,
|
||||
ModelPropertyKey.MAX_CHUNKS: 32,
|
||||
},
|
||||
pricing=PriceConfig(
|
||||
input=0.0001,
|
||||
|
@@ -30,7 +30,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
|
||||
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
||||
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
|
||||
|
||||
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||
# chat model
|
||||
@@ -59,7 +59,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
|
||||
model_mode = self._get_ai_model_entity(credentials['base_model_name'], model).entity.model_properties.get(
|
||||
model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
|
||||
ModelPropertyKey.MODE)
|
||||
|
||||
if model_mode == LLMMode.CHAT.value:
|
||||
@@ -79,7 +79,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
if 'base_model_name' not in credentials:
|
||||
raise CredentialsValidateFailedError('Base Model Name is required')
|
||||
|
||||
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
||||
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
|
||||
|
||||
if not ai_model_entity:
|
||||
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
|
||||
@@ -109,8 +109,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
||||
return ai_model_entity.entity
|
||||
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
|
||||
return ai_model_entity.entity if ai_model_entity else None
|
||||
|
||||
def _generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
|
||||
|
Reference in New Issue
Block a user