Feat/zhipuai function calling (#2199)
Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
@@ -3,7 +3,8 @@ from typing import Generator
|
||||
|
||||
import pytest
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, SystemPromptMessage,
|
||||
UserPromptMessage, PromptMessageTool)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.zhipuai.llm.llm import ZhipuAILargeLanguageModel
|
||||
|
||||
@@ -102,3 +103,48 @@ def test_get_num_tokens():
|
||||
)
|
||||
|
||||
assert num_tokens == 14
|
||||
|
||||
def test_get_tools_num_tokens():
|
||||
model = ZhipuAILargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='tools',
|
||||
credentials={
|
||||
'api_key': os.environ.get('ZHIPUAI_API_KEY')
|
||||
},
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name='get_current_weather',
|
||||
description='Get the current weather in a given location',
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"c",
|
||||
"f"
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"location"
|
||||
]
|
||||
}
|
||||
)
|
||||
],
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 108
|
@@ -42,7 +42,7 @@ def test_invoke_model():
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 2
|
||||
assert result.usage.total_tokens > 0
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
|
Reference in New Issue
Block a user