feat: optimize model when app create (#875)

This commit is contained in:
takatost
2023-08-16 22:29:18 +08:00
committed by GitHub
parent cc2d71c253
commit b7c29ea1b6
5 changed files with 104 additions and 23 deletions

View File

@@ -30,8 +30,9 @@ def decrypt_side_effect(tenant_id, encrypted_key):
@patch('huggingface_hub.hf_api.ModelInfo')
def test_hosted_inference_api_is_credentials_valid_or_raise_valid(mock_model_info, mocker):
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation')
mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc")
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation', cardData={'inference': True})
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value="abc")
mocker.patch('huggingface_hub.hf_api.HfApi.model_info', return_value=mock_model_info.return_value)
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',

View File

@@ -23,14 +23,31 @@ def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def version_effect(id: str):
mock_version = MagicMock()
mock_version.openapi_schema = {
'components': {
'schemas': {
'Output': {
'items': {
'type': 'string'
}
}
}
}
}
return mock_version
@patch('replicate.version.VersionCollection.get', side_effect=version_effect)
def test_is_credentials_valid_or_raise_valid(mocker):
mock_query = MagicMock()
mock_query.return_value = None
mocker.patch('replicate.model.ModelCollection.get', return_value=mock_query)
mocker.patch('replicate.model.Model.versions', return_value=mock_query)
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_name='username/test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials=VALIDATE_CREDENTIAL.copy()
)

View File

@@ -26,7 +26,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('langchain.llms.tongyi.Tongyi._generate', return_value=LLMResult(generations=[[Generation(text="abc")]]))
mocker.patch('core.third_party.langchain.llms.tongyi_llm.EnhanceTongyi._generate', return_value=LLMResult(generations=[[Generation(text="abc")]]))
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)