Feat: temporary support for Azure OpenAI (#101)

Author: John Wang (committed by GitHub)
Date: 2023-05-19 13:24:45 +08:00
Commit: f68b05d5ec (parent 3b3c604eb5)

16 changed files with 350 additions and 109 deletions


@@ -4,9 +4,14 @@ from langchain.callbacks import CallbackManager
 from langchain.llms.fake import FakeListLLM
 
 from core.constant import llm_constant
+from core.llm.error import ProviderTokenNotInitError
+from core.llm.provider.base import BaseProvider
 from core.llm.provider.llm_provider_service import LLMProviderService
+from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
+from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
+from models.provider import ProviderType
 
 
 class LLMBuilder:
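
The two Azure wrapper classes imported here are new in this commit; their implementations live in core/llm/streamable_azure_chat_open_ai.py and core/llm/streamable_azure_open_ai.py, which are among the 16 changed files but are not shown on this page. As a rough sketch only, assuming they mirror how the existing streamable classes wrap langchain's ChatOpenAI and OpenAI, they would subclass langchain's Azure model classes:

# Sketch only: the real implementations are in changed files not shown
# here; these class bodies are assumptions, not Dify's actual code.
from langchain.chat_models import AzureChatOpenAI
from langchain.llms import AzureOpenAI


class StreamableAzureChatOpenAI(AzureChatOpenAI):
    """Hypothetical chat-mode Azure wrapper with streaming support."""


class StreamableAzureOpenAI(AzureOpenAI):
    """Hypothetical completion-mode Azure wrapper with streaming support."""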
@@ -31,16 +36,23 @@ class LLMBuilder:
         if model_name == 'fake':
             return FakeListLLM(responses=[])
 
+        provider = cls.get_default_provider(tenant_id)
+
         mode = cls.get_mode_by_model(model_name)
         if mode == 'chat':
-            # llm_cls = StreamableAzureChatOpenAI
-            llm_cls = StreamableChatOpenAI
+            if provider == 'openai':
+                llm_cls = StreamableChatOpenAI
+            else:
+                llm_cls = StreamableAzureChatOpenAI
         elif mode == 'completion':
-            llm_cls = StreamableOpenAI
+            if provider == 'openai':
+                llm_cls = StreamableOpenAI
+            else:
+                llm_cls = StreamableAzureOpenAI
         else:
             raise ValueError(f"model name {model_name} is not supported.")
 
-        model_credentials = cls.get_model_credentials(tenant_id, model_name)
+        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
 
         return llm_cls(
             model_name=model_name,
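
Note that the routing treats 'openai' as the only non-Azure provider: anything else falls through to the Azure classes, which is consistent with the "temporary" scope of the commit. A minimal standalone sketch of the same dispatch follows; the helper name select_llm_class is illustrative, not part of the commit:

# Illustrative helper, not part of the commit: the same routing as the
# if/elif chain above, expressed as a pure function.
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI


def select_llm_class(provider: str, mode: str) -> type:
    """Map a (provider, mode) pair to a streamable LLM class.

    Any provider other than 'openai' falls through to the Azure
    classes, exactly as in the diff above.
    """
    if mode == 'chat':
        return StreamableChatOpenAI if provider == 'openai' else StreamableAzureChatOpenAI
    if mode == 'completion':
        return StreamableOpenAI if provider == 'openai' else StreamableAzureOpenAI
    raise ValueError(f"mode {mode} is not supported.")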
@@ -86,18 +98,31 @@ class LLMBuilder:
             raise ValueError(f"model name {model_name} is not supported.")
 
     @classmethod
-    def get_model_credentials(cls, tenant_id: str, model_name: str) -> dict:
+    def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
         """
         Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
         Raises an exception if the model_name is not found or if the provider is not found.
         """
         if not model_name:
             raise Exception('model name not found')
-
-        if model_name not in llm_constant.models:
-            raise Exception('model {} not found'.format(model_name))
-
-        model_provider = llm_constant.models[model_name]
+        #
+        # if model_name not in llm_constant.models:
+        #     raise Exception('model {} not found'.format(model_name))
+        #
+        # model_provider = llm_constant.models[model_name]
 
         provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
         return provider_service.get_credentials(model_name)
+
+    @classmethod
+    def get_default_provider(cls, tenant_id: str) -> str:
+        provider = BaseProvider.get_valid_provider(tenant_id)
+        if not provider:
+            raise ProviderTokenNotInitError()
+
+        if provider.provider_type == ProviderType.SYSTEM.value:
+            provider_name = 'openai'
+        else:
+            provider_name = provider.provider_name
+
+        return provider_name
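
Together the two methods give the builder a provider-first flow: resolve the tenant's provider, then look credentials up by that provider instead of deriving it from the static llm_constant.models table (now commented out). A hedged usage sketch follows; the tenant ID, model name, and the import path for LLMBuilder (its file name is not shown on this page) are placeholders:

# Illustrative usage only; the LLMBuilder import path and the literal
# values below are assumptions, not taken from this diff.
from core.llm.llm_builder import LLMBuilder
from core.llm.error import ProviderTokenNotInitError

try:
    # SYSTEM-type (hosted quota) providers always resolve to 'openai';
    # tenant-configured providers keep their own name, so a tenant with
    # Azure credentials gets its Azure provider name here.
    provider = LLMBuilder.get_default_provider('tenant-123')
except ProviderTokenNotInitError:
    raise  # no valid provider token configured for this tenant

# Credentials are now looked up by provider, not by the model table.
credentials = LLMBuilder.get_model_credentials('tenant-123', provider, 'gpt-3.5-turbo')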