feat/enhance the multi-modal support (#8818)
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
@@ -8,6 +7,7 @@ from typing import Optional, Union
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.callbacks.logging_callback import LoggingCallback
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
@@ -77,7 +77,7 @@ class LargeLanguageModel(AIModel):
|
||||
|
||||
callbacks = callbacks or []
|
||||
|
||||
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
|
||||
if dify_config.DEBUG:
|
||||
callbacks.append(LoggingCallback())
|
||||
|
||||
# trigger before invoke callbacks
|
||||
@@ -107,7 +107,16 @@ class LargeLanguageModel(AIModel):
|
||||
callbacks=callbacks,
|
||||
)
|
||||
else:
|
||||
result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
result = self._invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
self._trigger_invoke_error_callbacks(
|
||||
model=model,
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
@@ -22,8 +23,14 @@ class TTSModel(AIModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(
|
||||
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
|
||||
):
|
||||
self,
|
||||
model: str,
|
||||
tenant_id: str,
|
||||
credentials: dict,
|
||||
content_text: str,
|
||||
voice: str,
|
||||
user: Optional[str] = None,
|
||||
) -> Iterable[bytes]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@@ -50,8 +57,14 @@ class TTSModel(AIModel):
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(
|
||||
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
|
||||
):
|
||||
self,
|
||||
model: str,
|
||||
tenant_id: str,
|
||||
credentials: dict,
|
||||
content_text: str,
|
||||
voice: str,
|
||||
user: Optional[str] = None,
|
||||
) -> Iterable[bytes]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@@ -68,25 +81,25 @@ class TTSModel(AIModel):
|
||||
|
||||
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
|
||||
"""
|
||||
Get voice for given tts model voices
|
||||
Retrieves the list of voices supported by a given text-to-speech (TTS) model.
|
||||
|
||||
:param language: tts language
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: voices lists
|
||||
:param language: The language for which the voices are requested.
|
||||
:param model: The name of the TTS model.
|
||||
:param credentials: The credentials required to access the TTS model.
|
||||
:return: A list of voices supported by the TTS model.
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties:
|
||||
voices = model_schema.model_properties[ModelPropertyKey.VOICES]
|
||||
if language:
|
||||
return [
|
||||
{"name": d["name"], "value": d["mode"]}
|
||||
for d in voices
|
||||
if language and language in d.get("language")
|
||||
]
|
||||
else:
|
||||
return [{"name": d["name"], "value": d["mode"]} for d in voices]
|
||||
if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties:
|
||||
raise ValueError("this model does not support voice")
|
||||
|
||||
voices = model_schema.model_properties[ModelPropertyKey.VOICES]
|
||||
if language:
|
||||
return [
|
||||
{"name": d["name"], "value": d["mode"]} for d in voices if language and language in d.get("language")
|
||||
]
|
||||
else:
|
||||
return [{"name": d["name"], "value": d["mode"]} for d in voices]
|
||||
|
||||
def _get_model_default_voice(self, model: str, credentials: dict) -> Any:
|
||||
"""
|
||||
@@ -111,8 +124,10 @@ class TTSModel(AIModel):
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
|
||||
if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties:
|
||||
raise ValueError("this model does not support audio type")
|
||||
|
||||
return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
|
||||
|
||||
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
|
||||
"""
|
||||
@@ -121,8 +136,10 @@ class TTSModel(AIModel):
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
|
||||
if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties:
|
||||
raise ValueError("this model does not support word limit")
|
||||
|
||||
return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
|
||||
|
||||
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
|
||||
"""
|
||||
@@ -131,8 +148,10 @@ class TTSModel(AIModel):
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
|
||||
if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties:
|
||||
raise ValueError("this model does not support max workers")
|
||||
|
||||
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
|
||||
|
||||
@staticmethod
|
||||
def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"):
|
||||
|
@@ -1,3 +1,4 @@
|
||||
- gpt-4o-audio-preview
|
||||
- gpt-4
|
||||
- gpt-4o
|
||||
- gpt-4o-2024-05-13
|
||||
|
@@ -0,0 +1,44 @@
|
||||
model: gpt-4o-audio-preview
|
||||
label:
|
||||
zh_Hans: gpt-4o-audio-preview
|
||||
en_US: gpt-4o-audio-preview
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 4096
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '5.00'
|
||||
output: '15.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
import tiktoken
|
||||
from openai import OpenAI, Stream
|
||||
@@ -11,9 +11,9 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, Cho
|
||||
from openai.types.chat.chat_completion_message import FunctionCall
|
||||
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
AudioPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
@@ -23,6 +23,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType, PriceConfig
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
@@ -613,6 +614,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
# clear illegal prompt messages
|
||||
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
|
||||
|
||||
# o1 compatibility
|
||||
block_as_stream = False
|
||||
if model.startswith("o1"):
|
||||
if stream:
|
||||
@@ -626,8 +628,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
del extra_model_kwargs["stop"]
|
||||
|
||||
# chat model
|
||||
messages: Any = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
|
||||
response = client.chat.completions.create(
|
||||
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
|
||||
messages=messages,
|
||||
model=model,
|
||||
stream=stream,
|
||||
**model_parameters,
|
||||
@@ -946,23 +949,29 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
Convert PromptMessage to dict for OpenAI API
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
elif isinstance(message.content, list):
|
||||
sub_messages = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
if isinstance(message_content, TextPromptMessageContent):
|
||||
sub_message_dict = {"type": "text", "text": message_content.data}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
elif isinstance(message_content, ImagePromptMessageContent):
|
||||
sub_message_dict = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif isinstance(message_content, AudioPromptMessageContent):
|
||||
sub_message_dict = {
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": message_content.data,
|
||||
"format": message_content.format,
|
||||
},
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
|
||||
message_dict = {"role": "user", "content": sub_messages}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
|
Reference in New Issue
Block a user