feat:use xinference tts stream mode (#8616)

This commit is contained in:
呆萌闷油瓶
2024-09-22 10:08:35 +08:00
committed by GitHub
parent a587f0d3f1
commit c8b9bdebfe
5 changed files with 13 additions and 17 deletions

View File

@@ -19,7 +19,6 @@ from openai.types.chat.chat_completion_message import FunctionCall
from openai.types.completion import Completion
from xinference_client.client.restful.restful_client import (
Client,
RESTfulChatglmCppChatModelHandle,
RESTfulChatModelHandle,
RESTfulGenerateModelHandle,
)
@@ -491,7 +490,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
if tools and len(tools) > 0:
generate_config["tools"] = [{"type": "function", "function": helper.dump_model(tool)} for tool in tools]
vision = credentials.get("support_vision", False)
if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle):
if isinstance(xinference_model, RESTfulChatModelHandle):
resp = client.chat.completions.create(
model=credentials["model_uid"],
messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages],

View File

@@ -208,21 +208,21 @@ class XinferenceText2SpeechModel(TTSModel):
executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
futures = [
executor.submit(
handle.speech, input=sentences[i], voice=voice, response_format="mp3", speed=1.0, stream=False
handle.speech, input=sentences[i], voice=voice, response_format="mp3", speed=1.0, stream=True
)
for i in range(len(sentences))
]
for future in futures:
response = future.result()
for i in range(0, len(response), 1024):
yield response[i : i + 1024]
for chunk in response:
yield chunk
else:
response = handle.speech(
input=content_text.strip(), voice=voice, response_format="mp3", speed=1.0, stream=False
input=content_text.strip(), voice=voice, response_format="mp3", speed=1.0, stream=True
)
for i in range(0, len(response), 1024):
yield response[i : i + 1024]
for chunk in response:
yield chunk
except Exception as ex:
raise InvokeBadRequestError(str(ex))