feat: advanced prompt backend (#1301)
Co-authored-by: takatost <takatost@gmail.com>
@@ -1,7 +1,7 @@
 from typing import Type

 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
+from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, ModelMode
 from core.model_providers.models.llm.openai_model import OpenAIModel
 from core.model_providers.providers.base import BaseModelProvider

@@ -12,7 +12,10 @@ class FakeModelProvider(BaseModelProvider):
         return 'fake'

     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
-        return [{'id': 'test_model', 'name': 'Test Model'}]
+        return [{'id': 'test_model', 'name': 'Test Model', 'mode': 'completion'}]
+
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        return ModelMode.COMPLETION.value

     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         return OpenAIModel

@@ -24,7 +24,7 @@ def test_get_supported_model_list(mocker):
     provider = FakeModelProvider(provider=Provider())
     result = provider.get_supported_model_list(ModelType.TEXT_GENERATION)

-    assert result == [{'id': 'test_model', 'name': 'test_model'}]
+    assert result == [{'id': 'test_model', 'name': 'test_model', 'mode': 'completion'}]


 def test_check_quota_over_limit(mocker):
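For readers skimming the diff, here is a minimal, self-contained sketch of the behaviour the updated test pins down: the fake provider's fixed model list now carries a 'mode' key, and get_supported_model_list surfaces it next to 'id' and 'name'. The simplified classes below are illustrative stand-ins only, not the actual core.model_providers code; in particular, reusing the model id as the exposed name is an assumption inferred from the test's expected output.

from enum import Enum


class ModelMode(Enum):
    # Mirrors the value used in the diff; other members of the real enum are omitted here.
    COMPLETION = 'completion'


class FakeModelProviderSketch:
    """Toy stand-in for FakeModelProvider, kept just close enough to show the 'mode' plumbing."""

    def _get_fixed_model_list(self, model_type) -> list[dict]:
        # Each fixed entry now declares its prompt mode, as in the updated fixture above.
        return [{'id': 'test_model', 'name': 'Test Model', 'mode': 'completion'}]

    def _get_text_generation_model_mode(self, model_name) -> str:
        return ModelMode.COMPLETION.value

    def get_supported_model_list(self, model_type) -> list[dict]:
        # Assumption: the exposed name is derived from the model id, which is what
        # the test's expected {'name': 'test_model'} suggests.
        return [
            {'id': m['id'], 'name': m['id'], 'mode': m['mode']}
            for m in self._get_fixed_model_list(model_type)
        ]


if __name__ == '__main__':
    result = FakeModelProviderSketch().get_supported_model_list(None)
    assert result == [{'id': 'test_model', 'name': 'test_model', 'mode': 'completion'}]
    print(result)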