feat: support more model types and builtin tools on aws/sagemaker (#8061)

Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
This commit is contained in:
ybalbert001
2024-09-09 10:34:11 +08:00
committed by GitHub
parent ab7d79275e
commit 954580a4af
17 changed files with 1452 additions and 72 deletions

View File

@@ -1,17 +1,36 @@
import json
import logging
import re
from collections.abc import Generator, Iterator
from typing import Any, Optional, Union, cast
import boto3
from sagemaker import Predictor, serializers
from sagemaker.session import Session
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
I18nObject,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@@ -25,12 +44,140 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
def inference(predictor, messages: list[dict[str, Any]], params: dict[str, Any], stop: list, stream=False):
    """
    params:
        predictor: SageMaker Predictor
        messages (list[dict[str, Any]]): message list, e.g.
            messages = [
                {"role": "system", "content": "please answer in Chinese"},
                {"role": "user", "content": "who are you? what are you doing?"},
            ]
        params (dict[str, Any]): model parameters for the LLM
        stop (list): stop words
        stream (bool): False by default
    response:
        result of inference if stream is False
        iterator of chunks if stream is True
    """
    payload = {
        "model": params.get('model_name'),
        "stop": stop,
        "messages": messages,
        "stream": stream,
        "max_tokens": params.get('max_new_tokens', params.get('max_tokens', 2048)),
        "temperature": params.get('temperature', 0.1),
        "top_p": params.get('top_p', 0.9),
    }
    if not stream:
        response = predictor.predict(payload)
        return response
    else:
        response_stream = predictor.predict_stream(payload)
        return response_stream
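For context, a minimal usage sketch of this helper (the endpoint name is a placeholder; assumes the endpoint serves the OpenAI-compatible chat schema the payload above implies):

from sagemaker import Predictor, serializers
from sagemaker.session import Session

predictor = Predictor(
    endpoint_name="my-llm-endpoint",  # placeholder: an existing SageMaker endpoint
    sagemaker_session=Session(),
    serializer=serializers.JSONSerializer(),
)
result = inference(
    predictor,
    messages=[{"role": "user", "content": "Hello"}],
    params={"model_name": "qwen2-7b-instruct", "temperature": 0.1},  # illustrative values
    stop=[],
    stream=False,
)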
class SageMakerLargeLanguageModel(LargeLanguageModel):
"""
Model class for Cohere large language model.
"""
sagemaker_client: Any = None
sagemaker_sess : Any = None
predictor : Any = None
def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
resp: bytes) -> LLMResult:
"""
handle normal chat generate response
"""
resp_obj = json.loads(resp.decode('utf-8'))
resp_str = resp_obj.get('choices')[0].get('message').get('content')
if len(resp_str) == 0:
raise InvokeServerUnavailableError("Empty response")
assistant_prompt_message = AssistantPromptMessage(
content=resp_str,
tool_calls=[]
)
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens)
response = LLMResult(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=None,
usage=usage,
message=assistant_prompt_message,
)
return response
def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
resp: Iterator[bytes]) -> Generator:
"""
handle stream chat generate response
"""
full_response = ''
buffer = ""
for chunk_bytes in resp:
buffer += chunk_bytes.decode('utf-8')
last_idx = 0
for match in re.finditer(r'^data:\s*(.+?)(\n\n)', buffer, re.MULTILINE):
try:
data = json.loads(match.group(1).strip())
last_idx = match.span()[1]
if "content" in data["choices"][0]["delta"]:
chunk_content = data["choices"][0]["delta"]["content"]
assistant_prompt_message = AssistantPromptMessage(
content=chunk_content,
tool_calls=[]
)
if data["choices"][0]['finish_reason'] is not None:
temp_assistant_prompt_message = AssistantPromptMessage(
content=full_response,
tool_calls=[]
)
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=None,
delta=LLMResultChunkDelta(
index=0,
message=assistant_prompt_message,
finish_reason=data["choices"][0]['finish_reason'],
usage=usage
),
)
else:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=None,
delta=LLMResultChunkDelta(
index=0,
message=assistant_prompt_message
),
)
full_response += chunk_content
except (json.JSONDecodeError, KeyError, IndexError):
    logger.info("json parse exception, content: {}".format(match.group(1).strip()))
buffer = buffer[last_idx:]
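The regex above parses OpenAI-style server-sent events; for reference, the frames it consumes look roughly like this (illustrative values):

data: {"choices": [{"delta": {"content": "Hel"}, "finish_reason": null}]}

data: {"choices": [{"delta": {"content": "lo"}, "finish_reason": "stop"}]}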
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
@@ -50,9 +197,6 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
if not self.sagemaker_client:
access_key = credentials.get('access_key')
secret_key = credentials.get('secret_key')
@@ -68,37 +212,132 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
sagemaker_session = Session(sagemaker_runtime_client=self.sagemaker_client)
self.predictor = Predictor(
endpoint_name=credentials.get('sagemaker_endpoint'),
sagemaker_session=sagemaker_session,
serializer=serializers.JSONSerializer(),
)
messages: list[dict[str, Any]] = [{"role": p.role.value, "content": p.content} for p in prompt_messages]
response = inference(predictor=self.predictor, messages=messages, params=model_parameters, stop=stop, stream=stream)
if stream:
if tools and len(tools) > 0:
raise InvokeBadRequestError(f"{model}'s tool calls does not support stream mode")
return self._handle_chat_stream_response(model=model, credentials=credentials,
prompt_messages=prompt_messages,
tools=tools, resp=response)
return self._handle_chat_generate_response(model=model, credentials=credentials,
prompt_messages=prompt_messages,
tools=tools, resp=response)
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict for OpenAI Compatibility API
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(PromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {
"url": message_content.data,
"detail": message_content.detail.value
}
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls and len(message.tool_calls) > 0:
message_dict["function_call"] = {
"name": message.tool_calls[0].function.name,
"arguments": message.tool_calls[0].function.arguments
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
return message_dict
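As an illustration, a user message with mixed content converts to the OpenAI-compatible shape (values hypothetical):

# UserPromptMessage with one text part and one image part becomes:
# {
#     "role": "user",
#     "content": [
#         {"type": "text", "text": "What is in this picture?"},
#         {"type": "image_url", "image_url": {"url": "data:image/png;base64,...", "detail": "low"}},
#     ],
# }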
def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
is_completion_model: bool = False) -> int:
def tokens(text: str):
return self._get_num_tokens_by_gpt2(text)
if is_completion_model:
return sum(tokens(str(message.content)) for message in messages)
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
if isinstance(value, list):
text = ''
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item['text']
value = text
if key == "tool_calls":
for tool_call in value:
for t_key, t_value in tool_call.items():
num_tokens += tokens(t_key)
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += tokens(f_key)
num_tokens += tokens(f_value)
else:
num_tokens += tokens(t_key)
num_tokens += tokens(t_value)
if key == "function_call":
for t_key, t_value in value.items():
num_tokens += tokens(t_key)
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += tokens(f_key)
num_tokens += tokens(f_value)
else:
num_tokens += tokens(t_key)
num_tokens += tokens(t_value)
else:
num_tokens += tokens(str(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3
if tools:
num_tokens += self._num_tokens_for_tools(tools)
return num_tokens
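As a rough worked example of the accounting above (with the GPT-2 tokenizer standing in for the model's real one):

# messages = [{"role": "user", "content": "Hello"}]
# num_tokens = 3 (tokens_per_message)
#            + tokens("user") + tokens("Hello")   # the values of "role" and "content"
#            + 3 (assistant reply priming)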
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -112,10 +351,8 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
:return:
"""
try:
return self._num_tokens_from_messages(prompt_messages, tools)
except Exception as e:
raise self._transform_invoke_error(e)
@@ -129,7 +366,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
"""
try:
pass
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@@ -200,13 +437,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
)
]
completion_type = LLMMode.value_of(credentials["mode"]).value
features = []

View File

@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
class SageMakerRerankModel(RerankModel):
"""
Model class for SageMaker rerank model.
"""
sagemaker_client: Any = None

View File

@@ -1,10 +1,11 @@
import logging
import uuid
from typing import IO, Any
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class SageMakerProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
@@ -15,3 +16,28 @@ class SageMakerProvider(ModelProvider):
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
pass
def buffer_to_s3(s3_client: Any, file: IO[bytes], bucket: str, s3_prefix: str) -> str:
    '''
    Upload the file to S3 and return its object key.
    '''
    s3_key = f'{s3_prefix}{uuid.uuid4()}.mp3'
    s3_client.put_object(
        Body=file.read(),
        Bucket=bucket,
        Key=s3_key,
        ContentType='audio/mp3'
    )
    return s3_key
def generate_presigned_url(s3_client: Any, file: IO[bytes], bucket_name: str, s3_prefix: str, expiration=600) -> str:
    object_key = buffer_to_s3(s3_client, file, bucket_name, s3_prefix)
    try:
        response = s3_client.generate_presigned_url('get_object',
                                                    Params={'Bucket': bucket_name, 'Key': object_key},
                                                    ExpiresIn=expiration)
    except Exception as e:
        logger.exception(f"Error generating presigned URL: {e}")
        return None
    return response
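A quick sketch of how these helpers compose (the bucket name is a placeholder; assumes AWS credentials are available in the environment):

import boto3

s3 = boto3.client("s3")
with open("sample.mp3", "rb") as f:
    url = generate_presigned_url(s3, f, bucket_name="my-audio-bucket", s3_prefix="dify/speech2text/")
# `url` stays fetchable by the SageMaker endpoint for the next 600 seconds.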

View File

@@ -21,6 +21,8 @@ supported_model_types:
- llm
- text-embedding
- rerank
- speech2text
- tts
configurate_methods:
- customizable-model
model_credential_schema:
@@ -45,14 +47,10 @@ model_credential_schema:
zh_Hans: 选择对话类型
en_US: Select completion mode
options:
- value: chat
label:
en_US: Chat
zh_Hans: Chat
- variable: sagemaker_endpoint
label:
en_US: sagemaker endpoint
@@ -61,6 +59,76 @@ model_credential_schema:
placeholder:
zh_Hans: 请输入你的Sagemaker推理端点
en_US: Enter your Sagemaker Inference endpoint
- variable: audio_s3_cache_bucket
show_on:
- variable: __model_type
value: speech2text
label:
zh_Hans: 音频缓存桶(s3 bucket)
en_US: audio cache bucket(s3 bucket)
type: text-input
required: true
placeholder:
zh_Hans: sagemaker-us-east-1-******207838
en_US: sagemaker-us-east-1-*******7838
- variable: audio_model_type
show_on:
- variable: __model_type
value: tts
label:
en_US: Audio model type
type: select
required: true
placeholder:
zh_Hans: 语音模型类型
en_US: Audio model type
options:
- value: PresetVoice
label:
en_US: preset voice
zh_Hans: 内置音色
- value: CloneVoice
label:
en_US: clone voice
zh_Hans: 克隆音色
- value: CloneVoice_CrossLingual
label:
en_US: crosslingual clone voice
zh_Hans: 跨语种克隆音色
- value: InstructVoice
label:
en_US: Instruct voice
zh_Hans: 文字指令音色
- variable: prompt_audio
show_on:
- variable: __model_type
value: tts
label:
en_US: Audio source to mimic
type: text-input
required: false
placeholder:
zh_Hans: 被模仿的音色音频
en_US: source audio to be mimicked
- variable: prompt_text
show_on:
- variable: __model_type
value: tts
label:
en_US: Prompt Audio Text
type: text-input
required: false
placeholder:
zh_Hans: 模仿音色的对应文本
en_US: text for the mimicked source audio
- variable: instruct_text
show_on:
- variable: __model_type
value: tts
label:
en_US: instruct text for speaker
type: text-input
required: false
- variable: aws_access_key_id
required: false
label:

View File

@@ -0,0 +1,142 @@
import json
import logging
from typing import IO, Any, Optional
import boto3
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.sagemaker.sagemaker import generate_presigned_url
logger = logging.getLogger(__name__)
class SageMakerSpeech2TextModel(Speech2TextModel):
"""
Model class for SageMaker speech-to-text model.
"""
sagemaker_client: Any = None
s3_client : Any = None
def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
asr_text = None
try:
if not self.sagemaker_client:
access_key = credentials.get('aws_access_key_id')
secret_key = credentials.get('aws_secret_access_key')
aws_region = credentials.get('aws_region')
if aws_region:
if access_key and secret_key:
self.sagemaker_client = boto3.client("sagemaker-runtime",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
self.s3_client = boto3.client("s3",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
self.s3_client = boto3.client("s3", region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
self.s3_client = boto3.client("s3")
s3_prefix = 'dify/speech2text/'
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
bucket = credentials.get('audio_s3_cache_bucket')
s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix)
payload = {
"audio_s3_presign_uri" : s3_presign_url
}
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=sagemaker_endpoint,
Body=json.dumps(payload),
ContentType="application/json"
)
json_str = response_model['Body'].read().decode('utf8')
json_obj = json.loads(json_str)
asr_text = json_obj['text']
except Exception as e:
logger.exception(f'Exception: {e}')
return asr_text
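For reference, the request/response contract this code assumes of the ASR endpoint (values illustrative, presign URL truncated):

# request body sent to the endpoint
{"audio_s3_presign_uri": "https://my-bucket.s3.amazonaws.com/dify/speech2text/....mp3?X-Amz-..."}
# response body expected back
{"text": "transcribed speech"}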
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
pass
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError,
KeyError,
ValueError
]
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.SPEECH2TEXT,
model_properties={ },
parameter_rules=[]
)
return entity

View File

@@ -0,0 +1,287 @@
import concurrent.futures
import copy
import json
import logging
from enum import Enum
from typing import Any, Optional
import boto3
import requests
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.tts_model import TTSModel
logger = logging.getLogger(__name__)
class TTSModelType(Enum):
PresetVoice = "PresetVoice"
CloneVoice = "CloneVoice"
CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
InstructVoice = "InstructVoice"
class SageMakerText2SpeechModel(TTSModel):
sagemaker_client: Any = None
s3_client : Any = None
comprehend_client : Any = None
def __init__(self):
# preset voices; custom voices are not supported yet
self.model_voices = {
'__default': {
'all': [
{'name': 'Default', 'value': 'default'},
]
},
'CosyVoice': {
'zh-Hans': [
{'name': '中文男', 'value': '中文男'},
{'name': '中文女', 'value': '中文女'},
{'name': '粤语女', 'value': '粤语女'},
],
'zh-Hant': [
{'name': '中文男', 'value': '中文男'},
{'name': '中文女', 'value': '中文女'},
{'name': '粤语女', 'value': '粤语女'},
],
'en-US': [
{'name': '英文男', 'value': '英文男'},
{'name': '英文女', 'value': '英文女'},
],
'ja-JP': [
{'name': '日语男', 'value': '日语男'},
],
'ko-KR': [
{'name': '韩语女', 'value': '韩语女'},
]
}
}
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
pass
def _detect_lang_code(self, content: str, map_dict: Optional[dict] = None):
    map_dict = map_dict or {
        "zh": "<|zh|>",
        "en": "<|en|>",
        "ja": "<|jp|>",
        "zh-TW": "<|yue|>",
        "ko": "<|ko|>"
    }
    response = self.comprehend_client.detect_dominant_language(Text=content)
    language_code = response['Languages'][0]['LanguageCode']
    return map_dict.get(language_code, '<|zh|>')
def _build_tts_payload(self, model_type: str, content_text: str, model_role: str, prompt_text: str, prompt_audio: str, instruct_text: str):
    if model_type == TTSModelType.PresetVoice.value and model_role:
        return {"tts_text": content_text, "role": model_role}
    if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
        return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
    if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
        lang_tag = self._detect_lang_code(content_text)
        return {"tts_text": content_text, "prompt_audio": prompt_audio, "lang_tag": lang_tag}
    if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
        return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text}
    raise RuntimeError(f"Invalid params for {model_type}")
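To make the four modes concrete, the payloads come out roughly like this (all values illustrative):

# PresetVoice:             {"tts_text": "你好", "role": "中文女"}
# CloneVoice:              {"tts_text": "你好", "prompt_text": "reference transcript", "prompt_audio": "https://..."}
# CloneVoice_CrossLingual: {"tts_text": "Hello", "prompt_audio": "https://...", "lang_tag": "<|en|>"}
# InstructVoice:           {"tts_text": "你好", "role": "中文女", "instruct_text": "speak slowly"}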
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
user: Optional[str] = None):
"""
_invoke text2speech model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param voice: model timbre
:param content_text: text content to be translated
:param user: unique user id
:return: text translated to audio file
"""
if not self.sagemaker_client:
access_key = credentials.get('aws_access_key_id')
secret_key = credentials.get('aws_secret_access_key')
aws_region = credentials.get('aws_region')
if aws_region:
if access_key and secret_key:
self.sagemaker_client = boto3.client("sagemaker-runtime",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
self.s3_client = boto3.client("s3",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
self.comprehend_client = boto3.client('comprehend',
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
self.s3_client = boto3.client("s3", region_name=aws_region)
self.comprehend_client = boto3.client('comprehend', region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
self.s3_client = boto3.client("s3")
self.comprehend_client = boto3.client('comprehend')
model_type = credentials.get('audio_model_type', 'PresetVoice')
prompt_text = credentials.get('prompt_text')
prompt_audio = credentials.get('prompt_audio')
instruct_text = credentials.get('instruct_text')
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
payload = self._build_tts_payload(
model_type,
content_text,
voice,
prompt_text,
prompt_audio,
instruct_text
)
return self._tts_invoke_streaming(model_type, payload, sagemaker_endpoint)
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TTS,
model_properties={},
parameter_rules=[]
)
return entity
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError,
KeyError,
ValueError
]
}
def _get_model_default_voice(self, model: str, credentials: dict) -> Any:
return ""
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
return 15
def _get_model_audio_type(self, model: str, credentials: dict) -> str:
return "mp3"
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
return 5
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
audio_model_name = 'CosyVoice'
for key, voices in self.model_voices.items():
if key in audio_model_name:
if language and language in voices:
return voices[language]
elif 'all' in voices:
return voices['all']
return self.model_voices['__default']['all']
def _invoke_sagemaker(self, payload:dict, endpoint:str):
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=endpoint,
Body=json.dumps(payload),
ContentType="application/json",
)
json_str = response_model['Body'].read().decode('utf8')
json_obj = json.loads(json_str)
return json_obj
def _tts_invoke_streaming(self, model_type: str, payload: dict, sagemaker_endpoint: str) -> Any:
    """
    Invoke the text2speech endpoint and stream the audio back in chunks

    :param model_type: TTS model type (PresetVoice, CloneVoice, ...)
    :param payload: request payload built by _build_tts_payload
    :param sagemaker_endpoint: name of the SageMaker inference endpoint
    :return: generator yielding audio bytes
    """
try:
lang_tag = ''
if model_type == TTSModelType.CloneVoice_CrossLingual.value:
lang_tag = payload.pop('lang_tag')
word_limit = self._get_model_word_limit(model='', credentials={})
content_text = payload.get("tts_text")
if len(content_text) > word_limit:
split_sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
sentences = [ f"{lang_tag}{s}" for s in split_sentences if len(s) ]
len_sent = len(sentences)
executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len_sent))
payloads = [ copy.deepcopy(payload) for i in range(len_sent) ]
for idx in range(len_sent):
payloads[idx]["tts_text"] = sentences[idx]
futures = [ executor.submit(
self._invoke_sagemaker,
payload=payload,
endpoint=sagemaker_endpoint,
)
for payload in payloads]
for future in futures:
resp = future.result()
audio_bytes = requests.get(resp.get('s3_presign_url')).content
for i in range(0, len(audio_bytes), 1024):
yield audio_bytes[i:i + 1024]
else:
    # prepend the detected language tag in the single-request path as well,
    # otherwise cross-lingual voice cloning loses it (lang_tag is '' for other modes)
    payload["tts_text"] = f"{lang_tag}{content_text}"
    resp = self._invoke_sagemaker(payload, sagemaker_endpoint)
audio_bytes = requests.get(resp.get('s3_presign_url')).content
for i in range(0, len(audio_bytes), 1024):
yield audio_bytes[i:i + 1024]
except Exception as ex:
raise InvokeBadRequestError(str(ex))
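A minimal consumption sketch for the streaming generator (endpoint name, credentials, and model id are placeholders):

tts = SageMakerText2SpeechModel()
chunks = tts._invoke(
    model="cosyvoice",  # assumption: a registered custom model
    tenant_id="tenant-1",
    credentials={
        "sagemaker_endpoint": "my-tts-endpoint",
        "audio_model_type": "PresetVoice",
        "aws_region": "us-east-1",
    },
    content_text="你好，世界",
    voice="中文女",
)
with open("out.mp3", "wb") as f:
    for chunk in chunks:  # 1 KiB chunks yielded by _tts_invoke_streaming
        f.write(chunk)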