feat: advanced prompt backend (#1301)
Co-authored-by: takatost <takatost@gmail.com>
@@ -44,7 +44,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
 def test_get_num_tokens(mock_decrypt):
     model = get_mock_model('claude-2')
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 6

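Every test hunk in this commit makes the same substitution: MessageType.HUMAN becomes MessageType.USER. A minimal sketch of the enum change this implies (member values are assumptions; the diff itself only shows the names SYSTEM, HUMAN, and USER in use):

import enum

class MessageType(enum.Enum):
    SYSTEM = 'system'        # unchanged; still used for system prompts in the chat tests below
    USER = 'user'            # replaces the former HUMAN member throughout the test suite
    ASSISTANT = 'assistant'  # assumed to exist; not exercised by these hunks
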
@@ -69,7 +69,7 @@ def test_chat_get_num_tokens(mock_decrypt, mocker):
     openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
     rst = openai_model.get_num_tokens([
         PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 22

@@ -48,7 +48,7 @@ def test_chat_get_num_tokens(mock_decrypt):
     model = get_mock_model('baichuan2-53b')
     rst = model.get_num_tokens([
         PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst > 0

@@ -59,7 +59,7 @@ def test_chat_run(mock_decrypt, mocker):

     model = get_mock_model('baichuan2-53b')
     messages = [
-        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
+        PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
     ]
     rst = model.run(
         messages,

@@ -73,7 +73,7 @@ def test_chat_stream_run(mock_decrypt, mocker):

     model = get_mock_model('baichuan2-53b', streaming=True)
     messages = [
-        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
+        PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
     ]
     rst = model.run(
         messages

@@ -71,7 +71,7 @@ def test_hosted_inference_api_get_num_tokens(mock_decrypt, mock_model_info, mock
         mocker
     )
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 5

@@ -88,7 +88,7 @@ def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocke
         mocker
     )
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 5

@@ -48,7 +48,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
 def test_get_num_tokens(mock_decrypt):
     model = get_mock_model('abab5.5-chat')
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 5

@@ -52,7 +52,7 @@ def test_chat_get_num_tokens(mock_decrypt):
     openai_model = get_mock_openai_model('gpt-3.5-turbo')
     rst = openai_model.get_num_tokens([
         PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 22

@@ -55,7 +55,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
 def test_get_num_tokens(mock_decrypt, mocker):
     model = get_mock_model('facebook/opt-125m', mocker)
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 5

@@ -58,7 +58,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
 def test_get_num_tokens(mock_decrypt, mocker):
     model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker)
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 7

@@ -52,7 +52,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
 def test_get_num_tokens(mock_decrypt):
     model = get_mock_model('spark')
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 6

@@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
 def test_get_num_tokens(mock_decrypt):
     model = get_mock_model('qwen-turbo')
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 5

@@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
 def test_get_num_tokens(mock_decrypt):
     model = get_mock_model('ernie-bot')
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 5

@@ -57,7 +57,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
 def test_get_num_tokens(mock_decrypt, mocker):
     model = get_mock_model('llama-2-chat', mocker)
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 5

@@ -46,7 +46,7 @@ def test_chat_get_num_tokens(mock_decrypt):
     model = get_mock_model('chatglm_lite')
     rst = model.get_num_tokens([
         PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst > 0

@@ -57,7 +57,7 @@ def test_chat_run(mock_decrypt, mocker):

     model = get_mock_model('chatglm_lite')
     messages = [
-        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
+        PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
     ]
     rst = model.run(
         messages,

@@ -71,7 +71,7 @@ def test_chat_stream_run(mock_decrypt, mocker):

     model = get_mock_model('chatglm_lite', streaming=True)
     messages = [
-        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
+        PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
     ]
     rst = model.run(
         messages

@@ -1,7 +1,7 @@
 from typing import Type

 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
+from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, ModelMode
 from core.model_providers.models.llm.openai_model import OpenAIModel
 from core.model_providers.providers.base import BaseModelProvider

@@ -12,7 +12,10 @@ class FakeModelProvider(BaseModelProvider):
         return 'fake'

     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
-        return [{'id': 'test_model', 'name': 'Test Model'}]
+        return [{'id': 'test_model', 'name': 'Test Model', 'mode': 'completion'}]
+
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        return ModelMode.COMPLETION.value

     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         return OpenAIModel

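The stub provider now advertises a mode for each model and implements _get_text_generation_model_mode. ModelMode.COMPLETION.value suggests an enum roughly like the following sketch (only COMPLETION is confirmed by this diff; CHAT is an assumed sibling, inferred from the chat-model tests above):

import enum

class ModelMode(enum.Enum):
    COMPLETION = 'completion'  # confirmed: returned by _get_text_generation_model_mode above
    CHAT = 'chat'              # assumption; a plausible second mode, not shown in this diff
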
@@ -24,7 +24,7 @@ def test_get_supported_model_list(mocker):
     provider = FakeModelProvider(provider=Provider())
     result = provider.get_supported_model_list(ModelType.TEXT_GENERATION)

-    assert result == [{'id': 'test_model', 'name': 'test_model'}]
+    assert result == [{'id': 'test_model', 'name': 'test_model', 'mode': 'completion'}]


 def test_check_quota_over_limit(mocker):

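Taken together, the updated call pattern across all providers looks like this hypothetical end-to-end sketch (the mock model factory and the exact token counts are provider-specific, as the hunks above show):

messages = [
    PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
    PromptMessage(type=MessageType.USER, content='Who is your manufacturer?'),
]
num_tokens = model.get_num_tokens(messages)  # e.g. 22 for gpt-3.5-turbo; 6 for claude-2 without the system message
response = model.run(messages)               # pass streaming=True to the mock factory for chunked output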