feat: support more model types and builtin tools on aws/sagemaker (#8061)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
This commit is contained in:
@@ -1,17 +1,36 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union
|
||||
import re
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
# from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
import boto3
|
||||
from sagemaker import Predictor, serializers
|
||||
from sagemaker.session import Session
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
I18nObject,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
@@ -25,12 +44,140 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def inference(predictor, messages: list[dict[str, Any]], params: dict[str, Any], stop: list, stream=False):
    """
    Send an OpenAI-style chat request to a SageMaker endpoint.

    :param predictor: SageMaker predictor; must offer ``predict`` and ``predict_stream``
    :param messages: chat history as role/content dicts, for example::

        messages = [
            {"role": "system", "content": "please answer in Chinese"},
            {"role": "user", "content": "who are you? what are you doing?"},
        ]

    :param params: model parameters; ``model_name``, ``max_new_tokens`` (falling back
        to ``max_tokens``, then 2048), ``temperature`` (default 0.1) and ``top_p``
        (default 0.9) are read from it
    :param stop: stop sequences, forwarded unchanged
    :param stream: request a streamed answer when true (False by default)
    :return: whatever ``predictor.predict`` returns when ``stream`` is false,
        otherwise whatever ``predictor.predict_stream`` returns
    """
    # "max_new_tokens" wins over "max_tokens" when both are present.
    token_budget = params.get('max_new_tokens', params.get('max_tokens', 2048))

    request_body = dict(
        model=params.get('model_name'),
        stop=stop,
        messages=messages,
        stream=stream,
        max_tokens=token_budget,
        temperature=params.get('temperature', 0.1),
        top_p=params.get('top_p', 0.9),
    )

    if stream:
        return predictor.predict_stream(request_body)
    return predictor.predict(request_body)
|
||||
|
||||
class SageMakerLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
Model class for SageMaker large language model.
|
||||
"""
|
||||
sagemaker_client: Any = None
|
||||
sagemaker_sess : Any = None
|
||||
predictor : Any = None
|
||||
|
||||
def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                   tools: list[PromptMessageTool],
                                   resp: bytes) -> LLMResult:
    """
    Handle a non-streaming chat generate response.

    The body is parsed as an OpenAI-style chat completion document and the text
    is read from ``choices[0].message.content``.

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages that were sent to the endpoint
    :param tools: tools offered to the model; only used for prompt token counting
    :param resp: raw response body, UTF-8 encoded JSON
    :return: the full result including estimated token usage
    :raises InvokeServerUnavailableError: when the response carries no text
    """
    resp_obj = json.loads(resp.decode('utf-8'))

    # A missing or null "content" must be reported as an empty response rather
    # than crash with a TypeError on len(None).
    choices = resp_obj.get('choices') or []
    first_message = (choices[0].get('message') or {}) if choices else {}
    resp_str = first_message.get('content')

    if not resp_str:
        raise InvokeServerUnavailableError("Empty response")

    assistant_prompt_message = AssistantPromptMessage(
        content=resp_str,
        tool_calls=[]
    )

    prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
    # Tool definitions belong to the prompt, not the completion; the streaming
    # handler counts completion tokens the same way.
    completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=[])

    usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
                                      completion_tokens=completion_tokens)

    return LLMResult(
        model=model,
        prompt_messages=prompt_messages,
        system_fingerprint=None,
        usage=usage,
        message=assistant_prompt_message,
    )
|
||||
|
||||
def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                 tools: list[PromptMessageTool],
                                 resp: Iterator[bytes]) -> Generator:
    """
    Handle a streaming chat generate response.

    The endpoint streams server-sent events of the form ``data: {json}`` followed
    by a blank line, where the JSON is an OpenAI-style chat completion chunk.
    Events may be split across, or batched within, the byte chunks of ``resp``.

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages that were sent to the endpoint
    :param tools: tools offered to the model; only used for prompt token counting
    :param resp: iterator over raw UTF-8 encoded byte chunks of the event stream
    :return: generator of LLMResultChunk; the chunk that carries a finish reason
        also carries the estimated token usage
    """
    # Function-scope import keeps this method independent of the module import block.
    import codecs

    # MULTILINE lets "^" match every complete event in the buffer, not only the
    # one at the very start, so events batched into one chunk are all emitted.
    event_pattern = re.compile(r'^data:\s*(.+?)(\n\n)', re.MULTILINE)
    # An incremental decoder tolerates multi-byte characters split across chunks.
    utf8_decoder = codecs.getincrementaldecoder('utf-8')()

    full_response = ''
    buffer = ""
    for chunk_bytes in resp:
        buffer += utf8_decoder.decode(chunk_bytes)
        last_idx = 0
        for match in event_pattern.finditer(buffer):
            # A complete event is consumed even when it cannot be parsed
            # (e.g. a "[DONE]" sentinel); it would never parse later either.
            last_idx = match.end()
            raw_event = match.group(1).strip()
            try:
                data = json.loads(raw_event)
                choice = data["choices"][0]
                chunk_content = choice["delta"].get("content") or ''
                finish_reason = choice.get('finish_reason')
            except (json.JSONDecodeError, KeyError, IndexError, TypeError, AttributeError):
                logger.info("json parse exception, content: {}".format(raw_event))
                continue

            if not chunk_content and finish_reason is None:
                # Deltas without text (e.g. role-only) carry nothing to emit.
                continue

            full_response += chunk_content
            assistant_prompt_message = AssistantPromptMessage(
                content=chunk_content,
                tool_calls=[]
            )

            if finish_reason is None:
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    system_fingerprint=None,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=assistant_prompt_message
                    ),
                )
                continue

            # Final chunk: estimate usage over the whole accumulated answer.
            temp_assistant_prompt_message = AssistantPromptMessage(
                content=full_response,
                tool_calls=[]
            )
            prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
            completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
            usage = self._calc_response_usage(model=model, credentials=credentials,
                                              prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)

            yield LLMResultChunk(
                model=model,
                prompt_messages=prompt_messages,
                system_fingerprint=None,
                delta=LLMResultChunkDelta(
                    index=0,
                    message=assistant_prompt_message,
                    finish_reason=finish_reason,
                    usage=usage
                ),
            )

        # Keep only the unfinished tail for the next chunk.
        buffer = buffer[last_idx:]
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
@@ -50,9 +197,6 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# get model mode
|
||||
model_mode = self.get_model_mode(model, credentials)
|
||||
|
||||
if not self.sagemaker_client:
|
||||
access_key = credentials.get('access_key')
|
||||
secret_key = credentials.get('secret_key')
|
||||
@@ -68,37 +212,132 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||
|
||||
sagemaker_session = Session(sagemaker_runtime_client=self.sagemaker_client)
|
||||
self.predictor = Predictor(
|
||||
endpoint_name=credentials.get('sagemaker_endpoint'),
|
||||
sagemaker_session=sagemaker_session,
|
||||
serializer=serializers.JSONSerializer(),
|
||||
)
|
||||
|
||||
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
|
||||
response_model = self.sagemaker_client.invoke_endpoint(
|
||||
EndpointName=sagemaker_endpoint,
|
||||
Body=json.dumps(
|
||||
{
|
||||
"inputs": prompt_messages[0].content,
|
||||
"parameters": { "stop" : stop},
|
||||
"history" : []
|
||||
}
|
||||
),
|
||||
ContentType="application/json",
|
||||
)
|
||||
|
||||
assistant_text = response_model['Body'].read().decode('utf8')
|
||||
messages:list[dict[str,Any]] = [ {"role": p.role.value, "content": p.content} for p in prompt_messages ]
|
||||
response = inference(predictor=self.predictor, messages=messages, params=model_parameters, stop=stop, stream=stream)
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=assistant_text
|
||||
)
|
||||
if stream:
|
||||
if tools and len(tools) > 0:
|
||||
raise InvokeBadRequestError(f"{model}'s tool calls does not support stream mode")
|
||||
|
||||
usage = self._calc_response_usage(model, credentials, 0, 0)
|
||||
return self._handle_chat_stream_response(model=model, credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools, resp=response)
|
||||
return self._handle_chat_generate_response(model=model, credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools, resp=response)
|
||||
|
||||
response = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage
|
||||
)
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
    """
    Convert PromptMessage to dict for OpenAI Compatibility API

    :param message: user, assistant, system or tool prompt message
    :return: dict with at least ``role`` and ``content``; assistant messages with
        tool calls additionally carry ``function_call`` (first tool call only),
        tool messages carry ``tool_call_id``
    :raises ValueError: when the message is of none of the supported types
    """
    if isinstance(message, UserPromptMessage):
        message = cast(UserPromptMessage, message)
        if isinstance(message.content, str):
            message_dict = {"role": "user", "content": message.content}
        else:
            # Multi-part content: keep text and image parts, drop anything else.
            sub_messages = []
            for message_content in message.content:
                if message_content.type == PromptMessageContentType.TEXT:
                    message_content = cast(PromptMessageContent, message_content)
                    sub_messages.append({
                        "type": "text",
                        "text": message_content.data
                    })
                elif message_content.type == PromptMessageContentType.IMAGE:
                    message_content = cast(ImagePromptMessageContent, message_content)
                    sub_messages.append({
                        "type": "image_url",
                        "image_url": {
                            "url": message_content.data,
                            "detail": message_content.detail.value
                        }
                    })
            message_dict = {"role": "user", "content": sub_messages}
    elif isinstance(message, AssistantPromptMessage):
        message = cast(AssistantPromptMessage, message)
        message_dict = {"role": "assistant", "content": message.content}
        if message.tool_calls and len(message.tool_calls) > 0:
            # Only the first tool call is forwarded, as a legacy "function_call".
            message_dict["function_call"] = {
                "name": message.tool_calls[0].function.name,
                "arguments": message.tool_calls[0].function.arguments
            }
    elif isinstance(message, SystemPromptMessage):
        message = cast(SystemPromptMessage, message)
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, ToolPromptMessage):
        message = cast(ToolPromptMessage, message)
        message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
    else:
        raise ValueError(f"Unknown message type {type(message)}")

    # The converted dict is the only result; no other value is returned here.
    return message_dict
|
||||
|
||||
def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
                              is_completion_model: bool = False) -> int:
    """
    Estimate the number of tokens used by a list of messages.

    Counting is done with the GPT-2 tokenizer helper, so the result is an
    approximation for whatever model actually runs behind the endpoint.

    :param messages: prompt messages to count
    :param tools: tool definitions; their token count is added when non-empty
    :param is_completion_model: when true, only the stringified ``content`` of
        each message is counted and no chat overhead is added
    :return: estimated token count
    """
    def tokens(text: str):
        return self._get_num_tokens_by_gpt2(text)

    def call_tokens(call: dict) -> int:
        """Count each key once plus its value(s) for one tool/function call dict."""
        total = 0
        for t_key, t_value in call.items():
            total += tokens(t_key)
            if t_key == "function":
                for f_key, f_value in t_value.items():
                    total += tokens(f_key)
                    total += tokens(f_value)
            else:
                total += tokens(t_value)
        return total

    if is_completion_model:
        return sum(tokens(str(message.content)) for message in messages)

    tokens_per_message = 3
    tokens_per_name = 1

    num_tokens = 0
    messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
    for message in messages_dict:
        num_tokens += tokens_per_message
        for key, value in message.items():
            # The three cases are mutually exclusive; call structures must not
            # additionally be counted as a stringified blob.
            if key == "tool_calls":
                for tool_call in value:
                    num_tokens += call_tokens(tool_call)
            elif key == "function_call":
                num_tokens += call_tokens(value)
            else:
                if isinstance(value, list):
                    # Multi-part content: only the text parts are counted.
                    value = ''.join(
                        item['text'] for item in value
                        if isinstance(item, dict) and item.get('type') == 'text'
                    )
                num_tokens += tokens(str(value))

            if key == "name":
                num_tokens += tokens_per_name

    # Fixed allowance added once per request, on top of the per-message overhead.
    num_tokens += 3

    if tools:
        num_tokens += self._num_tokens_for_tools(tools)

    return num_tokens
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
@@ -112,10 +351,8 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
|
||||
:return:
|
||||
"""
|
||||
# get model mode
|
||||
model_mode = self.get_model_mode(model)
|
||||
|
||||
try:
|
||||
return 0
|
||||
return self._num_tokens_from_messages(prompt_messages, tools)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@@ -129,7 +366,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
try:
|
||||
# get model mode
|
||||
model_mode = self.get_model_mode(model)
|
||||
pass
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@@ -200,13 +437,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
|
||||
)
|
||||
]
|
||||
|
||||
completion_type = LLMMode.value_of(credentials["mode"])
|
||||
|
||||
if completion_type == LLMMode.CHAT:
|
||||
print(f"completion_type : {LLMMode.CHAT.value}")
|
||||
|
||||
if completion_type == LLMMode.COMPLETION:
|
||||
print(f"completion_type : {LLMMode.COMPLETION.value}")
|
||||
completion_type = LLMMode.value_of(credentials["mode"]).value
|
||||
|
||||
features = []
|
||||
|
||||
|
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class SageMakerRerankModel(RerankModel):
|
||||
"""
|
||||
Model class for Cohere rerank model.
|
||||
Model class for SageMaker rerank model.
|
||||
"""
|
||||
sagemaker_client: Any = None
|
||||
|
||||
|
@@ -1,10 +1,11 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import IO, Any
|
||||
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SageMakerProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
@@ -15,3 +16,28 @@ class SageMakerProvider(ModelProvider):
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
pass
|
||||
|
||||
def buffer_to_s3(s3_client: Any, file: IO[bytes], bucket: str, s3_prefix: str) -> str:
    """
    Upload the remaining content of ``file`` to S3 and return the object key.

    The key is ``s3_prefix`` followed by a random UUID and the ``.mp3`` suffix;
    the object is stored with content type ``audio/mp3``.

    :param s3_client: client object offering ``put_object``
    :param file: readable binary stream holding the audio
    :param bucket: destination bucket name
    :param s3_prefix: key prefix, used verbatim (no separator is inserted)
    :return: key of the uploaded object (not a full ``s3://`` URI)
    """
    object_key = f'{s3_prefix}{uuid.uuid4()}.mp3'
    audio_payload = file.read()
    s3_client.put_object(
        Body=audio_payload,
        Bucket=bucket,
        Key=object_key,
        ContentType='audio/mp3',
    )
    return object_key
|
||||
|
||||
def generate_presigned_url(s3_client: Any, file: IO[bytes], bucket_name: str, s3_prefix: str,
                           expiration=600) -> str | None:
    """
    Upload ``file`` to S3 and return a presigned GET URL for the uploaded object.

    :param s3_client: client object offering ``put_object`` and ``generate_presigned_url``
    :param file: readable binary stream holding the audio
    :param bucket_name: destination bucket name
    :param s3_prefix: key prefix for the uploaded object
    :param expiration: value passed as ``ExpiresIn`` when signing (600 by default)
    :return: the presigned URL, or None when signing fails; a failing upload is
        not handled here and propagates to the caller
    """
    object_key = buffer_to_s3(s3_client, file, bucket_name, s3_prefix)
    try:
        return s3_client.generate_presigned_url(
            'get_object',
            Params={'Bucket': bucket_name, 'Key': object_key},
            ExpiresIn=expiration,
        )
    except Exception:
        # Use the module logger (with traceback) instead of printing to stdout.
        logger.exception("Error generating presigned URL")
        return None
|
@@ -21,6 +21,8 @@ supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
- speech2text
|
||||
- tts
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
@@ -45,14 +47,10 @@ model_credential_schema:
|
||||
zh_Hans: 选择对话类型
|
||||
en_US: Select completion mode
|
||||
options:
|
||||
- value: completion
|
||||
label:
|
||||
en_US: Completion
|
||||
zh_Hans: 补全
|
||||
- value: chat
|
||||
label:
|
||||
en_US: Chat
|
||||
zh_Hans: 对话
|
||||
zh_Hans: Chat
|
||||
- variable: sagemaker_endpoint
|
||||
label:
|
||||
en_US: sagemaker endpoint
|
||||
@@ -61,6 +59,76 @@ model_credential_schema:
|
||||
placeholder:
|
||||
zh_Hans: 请输出你的Sagemaker推理端点
|
||||
en_US: Enter your Sagemaker Inference endpoint
|
||||
- variable: audio_s3_cache_bucket
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: speech2text
|
||||
label:
|
||||
zh_Hans: 音频缓存桶(s3 bucket)
|
||||
en_US: audio cache bucket(s3 bucket)
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: sagemaker-us-east-1-******207838
|
||||
en_US: sagemaker-us-east-1-*******7838
|
||||
- variable: audio_model_type
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: tts
|
||||
label:
|
||||
en_US: Audio model type
|
||||
type: select
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 语音模型类型
|
||||
en_US: Audio model type
|
||||
options:
|
||||
- value: PresetVoice
|
||||
label:
|
||||
en_US: preset voice
|
||||
zh_Hans: 内置音色
|
||||
- value: CloneVoice
|
||||
label:
|
||||
en_US: clone voice
|
||||
zh_Hans: 克隆音色
|
||||
- value: CloneVoice_CrossLingual
|
||||
label:
|
||||
en_US: crosslingual clone voice
|
||||
zh_Hans: 跨语种克隆音色
|
||||
- value: InstructVoice
|
||||
label:
|
||||
en_US: Instruct voice
|
||||
zh_Hans: 文字指令音色
|
||||
- variable: prompt_audio
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: tts
|
||||
label:
|
||||
en_US: Mock Audio Source
|
||||
type: text-input
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 被模仿的音色音频
|
||||
en_US: source audio to be mocked
|
||||
- variable: prompt_text
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: tts
|
||||
label:
|
||||
en_US: Prompt Audio Text
|
||||
type: text-input
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 模仿音色的对应文本
|
||||
en_US: text for the mocked source audio
|
||||
- variable: instruct_text
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: tts
|
||||
label:
|
||||
en_US: instruct text for speaker
|
||||
type: text-input
|
||||
required: false
|
||||
- variable: aws_access_key_id
|
||||
required: false
|
||||
label:
|
||||
|
@@ -0,0 +1,142 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import IO, Any, Optional
|
||||
|
||||
import boto3
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from core.model_runtime.model_providers.sagemaker.sagemaker import generate_presigned_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SageMakerSpeech2TextModel(Speech2TextModel):
    """
    Model class for SageMaker speech to text model.
    """
    # boto3 clients, created on the first invocation and reused afterwards.
    sagemaker_client: Any = None
    s3_client: Any = None

    def _invoke(self, model: str, credentials: dict,
                file: IO[bytes], user: Optional[str] = None) \
            -> str:
        """
        Invoke speech2text model

        The audio is uploaded to the configured S3 bucket, a presigned URL for it
        is sent to the SageMaker endpoint, and the ``text`` field of the JSON
        answer is returned.

        :param model: model name
        :param credentials: model credentials; reads ``aws_access_key_id``,
            ``aws_secret_access_key``, ``aws_region``, ``sagemaker_endpoint`` and
            ``audio_s3_cache_bucket``
        :param file: audio file
        :param user: unique user id
        :return: text for given audio file, or None when the invocation failed
            (the failure is logged, not raised)
        """
        asr_text = None

        try:
            if not self.sagemaker_client:
                access_key = credentials.get('aws_access_key_id')
                secret_key = credentials.get('aws_secret_access_key')
                aws_region = credentials.get('aws_region')

                # Explicit keys are only used together with a region; otherwise
                # boto3 falls back to its default credential/region resolution.
                client_kwargs = {}
                if aws_region:
                    client_kwargs['region_name'] = aws_region
                    if access_key and secret_key:
                        client_kwargs['aws_access_key_id'] = access_key
                        client_kwargs['aws_secret_access_key'] = secret_key

                self.sagemaker_client = boto3.client("sagemaker-runtime", **client_kwargs)
                self.s3_client = boto3.client("s3", **client_kwargs)

            s3_prefix = 'dify/speech2text/'
            sagemaker_endpoint = credentials.get('sagemaker_endpoint')
            bucket = credentials.get('audio_s3_cache_bucket')

            s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix)
            payload = {
                "audio_s3_presign_uri": s3_presign_url
            }

            response_model = self.sagemaker_client.invoke_endpoint(
                EndpointName=sagemaker_endpoint,
                Body=json.dumps(payload),
                ContentType="application/json"
            )
            json_str = response_model['Body'].read().decode('utf8')
            json_obj = json.loads(json_str)
            asr_text = json_obj['text']
        except Exception as e:
            # The handler must not reference undefined names: a NameError raised
            # here would replace the real error.
            logger.exception(f'Exception {e}')

        return asr_text

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        No validation is performed; every credential set is accepted.

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        pass

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema

        :param model: model name, also used as the display label
        :param credentials: model credentials (not read here)
        :return: speech2text model entity without properties or parameter rules
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.SPEECH2TEXT,
            model_properties={},
            parameter_rules=[]
        )

        return entity
|
287
api/core/model_runtime/model_providers/sagemaker/tts/tts.py
Normal file
287
api/core/model_runtime/model_providers/sagemaker/tts/tts.py
Normal file
@@ -0,0 +1,287 @@
|
||||
import concurrent.futures
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
import boto3
|
||||
import requests
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TTSModelType(Enum):
    """
    Synthesis modes of the SageMaker TTS endpoint.

    Each value is the string that is compared against the ``audio_model_type``
    credential when the request payload is built.
    """

    # speak with one of the built-in voices
    PresetVoice = "PresetVoice"
    # imitate the voice of a prompt audio, using its transcript
    CloneVoice = "CloneVoice"
    # imitate the voice of a prompt audio; the text gets a language tag
    CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
    # built-in voice steered by an instruction text
    InstructVoice = "InstructVoice"
|
||||
|
||||
class SageMakerText2SpeechModel(TTSModel):
    """
    Model class for SageMaker text to speech model.

    The endpoint answers each request with a JSON document whose
    ``s3_presign_url`` points at the synthesized audio; the audio is downloaded
    and handed to the caller in 1024-byte pieces.
    """

    # boto3 clients, created on the first invocation and reused afterwards.
    sagemaker_client: Any = None
    s3_client: Any = None
    comprehend_client: Any = None

    def __init__(self):
        # NOTE(review): the base-class initializer is not called here — confirm
        # that TTSModel needs no initialization of its own.
        # preset voices, need support custom voice
        self.model_voices = {
            '__default': {
                'all': [
                    {'name': 'Default', 'value': 'default'},
                ]
            },
            'CosyVoice': {
                'zh-Hans': [
                    {'name': '中文男', 'value': '中文男'},
                    {'name': '中文女', 'value': '中文女'},
                    {'name': '粤语女', 'value': '粤语女'},
                ],
                'zh-Hant': [
                    {'name': '中文男', 'value': '中文男'},
                    {'name': '中文女', 'value': '中文女'},
                    {'name': '粤语女', 'value': '粤语女'},
                ],
                'en-US': [
                    {'name': '英文男', 'value': '英文男'},
                    {'name': '英文女', 'value': '英文女'},
                ],
                'ja-JP': [
                    {'name': '日语男', 'value': '日语男'},
                ],
                'ko-KR': [
                    {'name': '韩语女', 'value': '韩语女'},
                ]
            }
        }

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        No validation is performed; every credential set is accepted.

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        pass

    def _detect_lang_code(self, content: str, map_dict: Optional[dict] = None):
        """
        Detect the dominant language of ``content`` and map it to a language tag.

        :param content: text whose language is detected via the Comprehend client
        :param map_dict: optional mapping from detected language code to tag; the
            built-in mapping is used when omitted
        :return: the mapped tag, or ``<|zh|>`` when the detected code is not mapped
        """
        if map_dict is None:
            map_dict = {
                "zh": "<|zh|>",
                "en": "<|en|>",
                "ja": "<|jp|>",
                "zh-TW": "<|yue|>",
                "ko": "<|ko|>"
            }

        response = self.comprehend_client.detect_dominant_language(Text=content)
        language_code = response['Languages'][0]['LanguageCode']

        return map_dict.get(language_code, '<|zh|>')

    def _build_tts_payload(self, model_type: str, content_text: str, model_role: str, prompt_text: str,
                           prompt_audio: str, instruct_text: str):
        """
        Build the endpoint request payload for the given synthesis mode.

        :param model_type: one of the ``TTSModelType`` values
        :param content_text: text to synthesize
        :param model_role: voice name, needed for preset and instruct voices
        :param prompt_text: transcript of the prompt audio, needed for voice cloning
        :param prompt_audio: prompt audio reference, needed for both cloning modes
        :param instruct_text: instruction text, needed for instruct voices
        :return: payload dict; cross-lingual payloads also carry ``lang_tag``
        :raises RuntimeError: when the mode is unknown or a field it needs is empty
        """
        if model_type == TTSModelType.PresetVoice.value and model_role:
            return {"tts_text": content_text, "role": model_role}
        if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
            return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
        if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
            lang_tag = self._detect_lang_code(content_text)
            return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag}
        if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
            return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text}

        raise RuntimeError(f"Invalid params for {model_type}")

    def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
                user: Optional[str] = None):
        """
        _invoke text2speech model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials; reads ``aws_access_key_id``,
            ``aws_secret_access_key``, ``aws_region``, ``audio_model_type``,
            ``prompt_text``, ``prompt_audio``, ``instruct_text`` and
            ``sagemaker_endpoint``
        :param voice: model timbre
        :param content_text: text content to be translated
        :param user: unique user id
        :return: generator of audio byte pieces
        """
        if not self.sagemaker_client:
            access_key = credentials.get('aws_access_key_id')
            secret_key = credentials.get('aws_secret_access_key')
            aws_region = credentials.get('aws_region')

            # Explicit keys are only used together with a region; otherwise
            # boto3 falls back to its default credential/region resolution.
            client_kwargs = {}
            if aws_region:
                client_kwargs['region_name'] = aws_region
                if access_key and secret_key:
                    client_kwargs['aws_access_key_id'] = access_key
                    client_kwargs['aws_secret_access_key'] = secret_key

            self.sagemaker_client = boto3.client("sagemaker-runtime", **client_kwargs)
            self.s3_client = boto3.client("s3", **client_kwargs)
            self.comprehend_client = boto3.client('comprehend', **client_kwargs)

        model_type = credentials.get('audio_model_type', 'PresetVoice')
        prompt_text = credentials.get('prompt_text')
        prompt_audio = credentials.get('prompt_audio')
        instruct_text = credentials.get('instruct_text')
        sagemaker_endpoint = credentials.get('sagemaker_endpoint')
        payload = self._build_tts_payload(
            model_type,
            content_text,
            voice,
            prompt_text,
            prompt_audio,
            instruct_text
        )

        return self._tts_invoke_streaming(model_type, payload, sagemaker_endpoint)

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema

        :param model: model name, also used as the display label
        :param credentials: model credentials (not read here)
        :return: TTS model entity without properties or parameter rules
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TTS,
            model_properties={},
            parameter_rules=[]
        )

        return entity

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }

    def _get_model_default_voice(self, model: str, credentials: dict) -> Any:
        """Return the default voice; this provider declares none (empty string)."""
        return ""

    def _get_model_word_limit(self, model: str, credentials: dict) -> int:
        """Return the text length above which a request is split into sentences."""
        return 15

    def _get_model_audio_type(self, model: str, credentials: dict) -> str:
        """Return the audio format reported for this model."""
        return "mp3"

    def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
        """Return the worker limit reported for this model."""
        return 5

    def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
        """
        Return the preset voices for the requested language.

        :param model: model name (not read; the voice table of ``CosyVoice`` is used)
        :param credentials: model credentials (not read here)
        :param language: language code such as ``zh-Hans``; optional
        :return: list of ``{'name': ..., 'value': ...}`` dicts, falling back to the
            default voice list when nothing matches
        """
        audio_model_name = 'CosyVoice'
        for key, voices in self.model_voices.items():
            if key in audio_model_name:
                if language and language in voices:
                    return voices[language]
                elif 'all' in voices:
                    return voices['all']

        return self.model_voices['__default']['all']

    def _invoke_sagemaker(self, payload: dict, endpoint: str):
        """
        Send one JSON request to the SageMaker endpoint and return the parsed JSON answer.

        :param payload: request payload, serialized with ``json.dumps``
        :param endpoint: SageMaker endpoint name
        :return: decoded JSON response
        """
        response_model = self.sagemaker_client.invoke_endpoint(
            EndpointName=endpoint,
            Body=json.dumps(payload),
            ContentType="application/json",
        )
        json_str = response_model['Body'].read().decode('utf8')
        json_obj = json.loads(json_str)
        return json_obj

    def _tts_invoke_streaming(self, model_type: str, payload: dict, sagemaker_endpoint: str) -> Any:
        """
        _tts_invoke_streaming text2speech model

        Text longer than the word limit is split into sentences that are
        synthesized concurrently and emitted in their original order; shorter
        text is synthesized with a single request.

        :param model_type: one of the ``TTSModelType`` values
        :param payload: request payload from ``_build_tts_payload``; a ``lang_tag``
            entry is removed from it and prefixed to the text instead
        :param sagemaker_endpoint: SageMaker endpoint name
        :return: generator yielding the audio in pieces of at most 1024 bytes
        :raises InvokeBadRequestError: on any failure while synthesizing or downloading
        """
        try:
            lang_tag = ''
            if model_type == TTSModelType.CloneVoice_CrossLingual.value:
                lang_tag = payload.pop('lang_tag')

            word_limit = self._get_model_word_limit(model='', credentials={})
            content_text = payload.get("tts_text")
            if len(content_text) > word_limit:
                split_sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
                sentences = [f"{lang_tag}{s}" for s in split_sentences if len(s)]
                if not sentences:
                    # Avoid creating a pool with zero workers, which fails with
                    # an unrelated-looking error message.
                    raise ValueError("No synthesizable text after sentence splitting")

                payloads = []
                for sentence in sentences:
                    sentence_payload = copy.deepcopy(payload)
                    sentence_payload["tts_text"] = sentence
                    payloads.append(sentence_payload)

                # The context manager shuts the pool down once the audio has been
                # emitted or the generator is closed early.
                with concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len(sentences))) as executor:
                    futures = [
                        executor.submit(
                            self._invoke_sagemaker,
                            payload=sentence_payload,
                            endpoint=sagemaker_endpoint,
                        )
                        for sentence_payload in payloads
                    ]

                    # Consume in submission order so the audio stays in text order.
                    for future in futures:
                        resp = future.result()
                        audio_bytes = requests.get(resp.get('s3_presign_url')).content
                        for i in range(0, len(audio_bytes), 1024):
                            yield audio_bytes[i:i + 1024]
            else:
                # Short text needs the language tag as well, exactly like the
                # split sentences above.
                single_payload = {**payload, "tts_text": f"{lang_tag}{content_text}"}
                resp = self._invoke_sagemaker(single_payload, sagemaker_endpoint)
                audio_bytes = requests.get(resp.get('s3_presign_url')).content

                for i in range(0, len(audio_bytes), 1024):
                    yield audio_bytes[i:i + 1024]
        except Exception as ex:
            raise InvokeBadRequestError(str(ex)) from ex
|
Reference in New Issue
Block a user