feat: support json schema for gemini models (#10835)

This commit is contained in:
非法操作
2024-11-19 17:49:58 +08:00
committed by GitHub
parent 9f195df103
commit bc1013dacf
18 changed files with 61 additions and 77 deletions

View File

@@ -31,7 +31,7 @@ def test_invoke_model(setup_google_mock):
model = GoogleLargeLanguageModel()
response = model.invoke(
model="gemini-pro",
model="gemini-1.5-pro",
credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
@@ -48,7 +48,7 @@ def test_invoke_model(setup_google_mock):
]
),
],
model_parameters={"temperature": 0.5, "top_p": 1.0, "max_tokens_to_sample": 2048},
model_parameters={"temperature": 0.5, "top_p": 1.0, "max_output_tokens": 2048},
stop=["How"],
stream=False,
user="abc-123",
@@ -63,7 +63,7 @@ def test_invoke_stream_model(setup_google_mock):
model = GoogleLargeLanguageModel()
response = model.invoke(
model="gemini-pro",
model="gemini-1.5-pro",
credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
@@ -80,7 +80,7 @@ def test_invoke_stream_model(setup_google_mock):
]
),
],
model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens_to_sample": 2048},
model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens": 2048},
stream=True,
user="abc-123",
)
@@ -99,7 +99,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):
model = GoogleLargeLanguageModel()
result = model.invoke(
model="gemini-pro-vision",
model="gemini-1.5-pro",
credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
@@ -128,7 +128,7 @@ def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
model = GoogleLargeLanguageModel()
result = model.invoke(
model="gemini-pro-vision",
model="gemini-1.5-pro",
credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
prompt_messages=[
SystemPromptMessage(content="You are a helpful AI assistant."),
@@ -164,7 +164,7 @@ def test_get_num_tokens():
model = GoogleLargeLanguageModel()
num_tokens = model.get_num_tokens(
model="gemini-pro",
model="gemini-1.5-pro",
credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
prompt_messages=[
SystemPromptMessage(