fix: invalid new tool call creation logic during response handling in OAI-Compat model (#17781)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
@@ -24,6 +25,58 @@ from core.plugin.manager.model import PluginModelManager
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _gen_tool_call_id() -> str:
|
||||
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
|
||||
|
||||
|
||||
def _increase_tool_call(
|
||||
new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
|
||||
):
|
||||
"""
|
||||
Merge incremental tool call updates into existing tool calls.
|
||||
|
||||
:param new_tool_calls: List of new tool call deltas to be merged.
|
||||
:param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
|
||||
"""
|
||||
|
||||
def get_tool_call(tool_call_id: str):
|
||||
"""
|
||||
Get or create a tool call by ID
|
||||
|
||||
:param tool_call_id: tool call ID
|
||||
:return: existing or new tool call
|
||||
"""
|
||||
if not tool_call_id:
|
||||
return existing_tools_calls[-1]
|
||||
|
||||
_tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None)
|
||||
if _tool_call is None:
|
||||
_tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
|
||||
)
|
||||
existing_tools_calls.append(_tool_call)
|
||||
|
||||
return _tool_call
|
||||
|
||||
for new_tool_call in new_tool_calls:
|
||||
# generate ID for tool calls with function name but no ID to track them
|
||||
if new_tool_call.function.name and not new_tool_call.id:
|
||||
new_tool_call.id = _gen_tool_call_id()
|
||||
# get tool call
|
||||
tool_call = get_tool_call(new_tool_call.id)
|
||||
# update tool call
|
||||
if new_tool_call.id:
|
||||
tool_call.id = new_tool_call.id
|
||||
if new_tool_call.type:
|
||||
tool_call.type = new_tool_call.type
|
||||
if new_tool_call.function.name:
|
||||
tool_call.function.name = new_tool_call.function.name
|
||||
if new_tool_call.function.arguments:
|
||||
tool_call.function.arguments += new_tool_call.function.arguments
|
||||
|
||||
|
||||
class LargeLanguageModel(AIModel):
|
||||
"""
|
||||
Model class for large language model.
|
||||
@@ -109,44 +162,13 @@ class LargeLanguageModel(AIModel):
|
||||
system_fingerprint = None
|
||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
|
||||
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
|
||||
def get_tool_call(tool_name: str):
|
||||
if not tool_name:
|
||||
return tools_calls[-1]
|
||||
|
||||
tool_call = next(
|
||||
(tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None
|
||||
)
|
||||
if tool_call is None:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id="",
|
||||
type="",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
|
||||
)
|
||||
tools_calls.append(tool_call)
|
||||
|
||||
return tool_call
|
||||
|
||||
for new_tool_call in new_tool_calls:
|
||||
# get tool call
|
||||
tool_call = get_tool_call(new_tool_call.function.name)
|
||||
# update tool call
|
||||
if new_tool_call.id:
|
||||
tool_call.id = new_tool_call.id
|
||||
if new_tool_call.type:
|
||||
tool_call.type = new_tool_call.type
|
||||
if new_tool_call.function.name:
|
||||
tool_call.function.name = new_tool_call.function.name
|
||||
if new_tool_call.function.arguments:
|
||||
tool_call.function.arguments += new_tool_call.function.arguments
|
||||
|
||||
for chunk in result:
|
||||
if isinstance(chunk.delta.message.content, str):
|
||||
content += chunk.delta.message.content
|
||||
elif isinstance(chunk.delta.message.content, list):
|
||||
content_list.extend(chunk.delta.message.content)
|
||||
if chunk.delta.message.tool_calls:
|
||||
increase_tool_call(chunk.delta.message.tool_calls)
|
||||
_increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
|
||||
|
||||
usage = chunk.delta.usage or LLMUsage.empty_usage()
|
||||
system_fingerprint = chunk.system_fingerprint
|
||||
|
Reference in New Issue
Block a user