Feat: optimize error desc (#152)
This commit is contained in:
@@ -1,22 +1,24 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.llm.provider.base import BaseProvider
|
||||
from core.llm.provider.errors import ValidateFailedError
|
||||
from models.provider import ProviderName
|
||||
|
||||
|
||||
class AzureProvider(BaseProvider):
|
||||
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
|
||||
credentials = self.get_credentials(model_id)
|
||||
def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
|
||||
credentials = self.get_credentials(model_id) if not credentials else credentials
|
||||
url = "{}/openai/deployments?api-version={}".format(
|
||||
credentials.get('openai_api_base'),
|
||||
credentials.get('openai_api_version')
|
||||
str(credentials.get('openai_api_base')),
|
||||
str(credentials.get('openai_api_version'))
|
||||
)
|
||||
|
||||
headers = {
|
||||
"api-key": credentials.get('openai_api_key'),
|
||||
"api-key": str(credentials.get('openai_api_key')),
|
||||
"content-type": "application/json; charset=utf-8"
|
||||
}
|
||||
|
||||
@@ -29,8 +31,10 @@ class AzureProvider(BaseProvider):
|
||||
'name': '{} ({})'.format(deployment['id'], deployment['model'])
|
||||
} for deployment in result['data'] if deployment['status'] == 'succeeded']
|
||||
else:
|
||||
# TODO: optimize in future
|
||||
raise Exception('Failed to get deployments from Azure OpenAI. Status code: {}'.format(response.status_code))
|
||||
if response.status_code == 401:
|
||||
raise AzureAuthenticationError()
|
||||
else:
|
||||
raise AzureRequestFailedError('Failed to request Azure OpenAI. Status code: {}'.format(response.status_code))
|
||||
|
||||
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
@@ -38,7 +42,7 @@ class AzureProvider(BaseProvider):
|
||||
"""
|
||||
config = self.get_provider_api_key(model_id=model_id)
|
||||
config['openai_api_type'] = 'azure'
|
||||
config['deployment_name'] = model_id.replace('.', '')
|
||||
config['deployment_name'] = model_id.replace('.', '') if model_id else None
|
||||
return config
|
||||
|
||||
def get_provider_name(self):
|
||||
@@ -54,7 +58,7 @@ class AzureProvider(BaseProvider):
|
||||
config = {
|
||||
'openai_api_type': 'azure',
|
||||
'openai_api_version': '2023-03-15-preview',
|
||||
'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
|
||||
'openai_api_base': '',
|
||||
'openai_api_key': ''
|
||||
}
|
||||
|
||||
@@ -63,7 +67,7 @@ class AzureProvider(BaseProvider):
|
||||
config = {
|
||||
'openai_api_type': 'azure',
|
||||
'openai_api_version': '2023-03-15-preview',
|
||||
'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
|
||||
'openai_api_base': '',
|
||||
'openai_api_key': ''
|
||||
}
|
||||
|
||||
@@ -80,8 +84,23 @@ class AzureProvider(BaseProvider):
|
||||
"""
|
||||
Validates the given config.
|
||||
"""
|
||||
# TODO: implement
|
||||
pass
|
||||
try:
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError('Config must be a object.')
|
||||
|
||||
if 'openai_api_version' not in config:
|
||||
config['openai_api_version'] = '2023-03-15-preview'
|
||||
|
||||
self.get_models(credentials=config)
|
||||
except AzureAuthenticationError:
|
||||
raise ValidateFailedError('Azure OpenAI Credentials validation failed, please check your API Key.')
|
||||
except requests.ConnectionError:
|
||||
raise ValidateFailedError('Azure OpenAI Credentials validation failed, please check your API Base Endpoint.')
|
||||
except AzureRequestFailedError as ex:
|
||||
raise ValidateFailedError('Azure OpenAI Credentials validation failed, error: {}.'.format(str(ex)))
|
||||
except Exception as ex:
|
||||
logging.exception('Azure OpenAI Credentials validation failed')
|
||||
raise ex
|
||||
|
||||
def get_encrypted_token(self, config: Union[dict | str]):
|
||||
"""
|
||||
@@ -101,3 +120,11 @@ class AzureProvider(BaseProvider):
|
||||
config = json.loads(token)
|
||||
config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
|
||||
return config
|
||||
|
||||
|
||||
class AzureAuthenticationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AzureRequestFailedError(Exception):
|
||||
pass
|
||||
|
Reference in New Issue
Block a user