feat: mypy for all type check (#10921)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
@@ -53,6 +52,7 @@ logger = logging.getLogger(__name__)
|
||||
class BaseAgentRunner(AppRunner):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
application_generate_entity: AgentChatAppGenerateEntity,
|
||||
conversation: Conversation,
|
||||
@@ -66,7 +66,7 @@ class BaseAgentRunner(AppRunner):
|
||||
prompt_messages: Optional[list[PromptMessage]] = None,
|
||||
variables_pool: Optional[ToolRuntimeVariablePool] = None,
|
||||
db_variables: Optional[ToolConversationVariables] = None,
|
||||
model_instance: ModelInstance | None = None,
|
||||
model_instance: ModelInstance,
|
||||
) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.application_generate_entity = application_generate_entity
|
||||
@@ -117,7 +117,7 @@ class BaseAgentRunner(AppRunner):
|
||||
features = model_schema.features if model_schema and model_schema.features else []
|
||||
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
|
||||
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
|
||||
self.query = None
|
||||
self.query: Optional[str] = ""
|
||||
self._current_thoughts: list[PromptMessage] = []
|
||||
|
||||
def _repack_app_generate_entity(
|
||||
@@ -145,7 +145,7 @@ class BaseAgentRunner(AppRunner):
|
||||
|
||||
message_tool = PromptMessageTool(
|
||||
name=tool.tool_name,
|
||||
description=tool_entity.description.llm,
|
||||
description=tool_entity.description.llm if tool_entity.description else "",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
@@ -167,7 +167,7 @@ class BaseAgentRunner(AppRunner):
|
||||
continue
|
||||
enum = []
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options]
|
||||
enum = [option.value for option in parameter.options] if parameter.options else []
|
||||
|
||||
message_tool.parameters["properties"][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
@@ -187,8 +187,8 @@ class BaseAgentRunner(AppRunner):
|
||||
convert dataset retriever tool to prompt message tool
|
||||
"""
|
||||
prompt_tool = PromptMessageTool(
|
||||
name=tool.identity.name,
|
||||
description=tool.description.llm,
|
||||
name=tool.identity.name if tool.identity else "unknown",
|
||||
description=tool.description.llm if tool.description else "",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
@@ -210,14 +210,14 @@ class BaseAgentRunner(AppRunner):
|
||||
|
||||
return prompt_tool
|
||||
|
||||
def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
|
||||
def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
|
||||
"""
|
||||
Init tools
|
||||
"""
|
||||
tool_instances = {}
|
||||
prompt_messages_tools = []
|
||||
|
||||
for tool in self.app_config.agent.tools if self.app_config.agent else []:
|
||||
for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
|
||||
try:
|
||||
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
|
||||
except Exception:
|
||||
@@ -234,7 +234,8 @@ class BaseAgentRunner(AppRunner):
|
||||
# save prompt tool
|
||||
prompt_messages_tools.append(prompt_tool)
|
||||
# save tool entity
|
||||
tool_instances[dataset_tool.identity.name] = dataset_tool
|
||||
if dataset_tool.identity is not None:
|
||||
tool_instances[dataset_tool.identity.name] = dataset_tool
|
||||
|
||||
return tool_instances, prompt_messages_tools
|
||||
|
||||
@@ -258,7 +259,7 @@ class BaseAgentRunner(AppRunner):
|
||||
continue
|
||||
enum = []
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options]
|
||||
enum = [option.value for option in parameter.options] if parameter.options else []
|
||||
|
||||
prompt_tool.parameters["properties"][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
@@ -322,16 +323,21 @@ class BaseAgentRunner(AppRunner):
|
||||
tool_name: str,
|
||||
tool_input: Union[str, dict],
|
||||
thought: str,
|
||||
observation: Union[str, dict],
|
||||
tool_invoke_meta: Union[str, dict],
|
||||
observation: Union[str, dict, None],
|
||||
tool_invoke_meta: Union[str, dict, None],
|
||||
answer: str,
|
||||
messages_ids: list[str],
|
||||
llm_usage: LLMUsage = None,
|
||||
) -> MessageAgentThought:
|
||||
llm_usage: LLMUsage | None = None,
|
||||
):
|
||||
"""
|
||||
Save agent thought
|
||||
"""
|
||||
agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
|
||||
queried_thought = (
|
||||
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
|
||||
)
|
||||
if not queried_thought:
|
||||
raise ValueError(f"Agent thought {agent_thought.id} not found")
|
||||
agent_thought = queried_thought
|
||||
|
||||
if thought is not None:
|
||||
agent_thought.thought = thought
|
||||
@@ -404,7 +410,7 @@ class BaseAgentRunner(AppRunner):
|
||||
"""
|
||||
convert tool variables to db variables
|
||||
"""
|
||||
db_variables = (
|
||||
queried_variables = (
|
||||
db.session.query(ToolConversationVariables)
|
||||
.filter(
|
||||
ToolConversationVariables.conversation_id == self.message.conversation_id,
|
||||
@@ -412,6 +418,11 @@ class BaseAgentRunner(AppRunner):
|
||||
.first()
|
||||
)
|
||||
|
||||
if not queried_variables:
|
||||
return
|
||||
|
||||
db_variables = queried_variables
|
||||
|
||||
db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
|
||||
db.session.commit()
|
||||
@@ -421,7 +432,7 @@ class BaseAgentRunner(AppRunner):
|
||||
"""
|
||||
Organize agent history
|
||||
"""
|
||||
result = []
|
||||
result: list[PromptMessage] = []
|
||||
# check if there is a system message in the beginning of the conversation
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, SystemPromptMessage):
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
@@ -12,6 +12,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
@@ -26,18 +27,18 @@ from models.model import Message
|
||||
class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
_is_first_iteration = True
|
||||
_ignore_observation_providers = ["wenxin"]
|
||||
_historic_prompt_messages: list[PromptMessage] = None
|
||||
_agent_scratchpad: list[AgentScratchpadUnit] = None
|
||||
_instruction: str = None
|
||||
_query: str = None
|
||||
_prompt_messages_tools: list[PromptMessage] = None
|
||||
_historic_prompt_messages: list[PromptMessage] | None = None
|
||||
_agent_scratchpad: list[AgentScratchpadUnit] | None = None
|
||||
_instruction: str = "" # FIXME this must be str for now
|
||||
_query: str | None = None
|
||||
_prompt_messages_tools: list[PromptMessageTool] = []
|
||||
|
||||
def run(
|
||||
self,
|
||||
message: Message,
|
||||
query: str,
|
||||
inputs: dict[str, str],
|
||||
) -> Union[Generator, LLMResult]:
|
||||
inputs: Mapping[str, str],
|
||||
) -> Generator:
|
||||
"""
|
||||
Run Cot agent application
|
||||
"""
|
||||
@@ -57,19 +58,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
# init instruction
|
||||
inputs = inputs or {}
|
||||
instruction = app_config.prompt_template.simple_prompt_template
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs)
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
|
||||
|
||||
function_call_state = True
|
||||
llm_usage = {"usage": None}
|
||||
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
final_answer = ""
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
final_llm_usage_dict["usage"] = usage
|
||||
else:
|
||||
@@ -90,7 +91,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
# the last iteration, remove all tools
|
||||
self._prompt_messages_tools = []
|
||||
|
||||
message_file_ids = []
|
||||
message_file_ids: list[str] = []
|
||||
|
||||
agent_thought = self.create_agent_thought(
|
||||
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
|
||||
@@ -105,7 +106,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
prompt_messages = self._organize_prompt_messages()
|
||||
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
|
||||
# invoke model
|
||||
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
|
||||
chunks = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_generate_entity.model_conf.parameters,
|
||||
tools=[],
|
||||
@@ -115,11 +116,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
if not isinstance(chunks, Generator):
|
||||
raise ValueError("Expected streaming response from LLM")
|
||||
|
||||
# check llm result
|
||||
if not chunks:
|
||||
raise ValueError("failed to invoke llm")
|
||||
|
||||
usage_dict = {}
|
||||
usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
@@ -139,25 +143,30 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
action = chunk
|
||||
# detect action
|
||||
scratchpad.agent_response += json.dumps(chunk.model_dump())
|
||||
if scratchpad.agent_response is not None:
|
||||
scratchpad.agent_response += json.dumps(chunk.model_dump())
|
||||
scratchpad.action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.action = action
|
||||
else:
|
||||
scratchpad.agent_response += chunk
|
||||
scratchpad.thought += chunk
|
||||
if scratchpad.agent_response is not None:
|
||||
scratchpad.agent_response += chunk
|
||||
if scratchpad.thought is not None:
|
||||
scratchpad.thought += chunk
|
||||
yield LLMResultChunk(
|
||||
model=self.model_config.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint="",
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
|
||||
)
|
||||
|
||||
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
|
||||
self._agent_scratchpad.append(scratchpad)
|
||||
if scratchpad.thought is not None:
|
||||
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
|
||||
if self._agent_scratchpad is not None:
|
||||
self._agent_scratchpad.append(scratchpad)
|
||||
|
||||
# get llm usage
|
||||
if "usage" in usage_dict:
|
||||
increase_usage(llm_usage, usage_dict["usage"])
|
||||
if usage_dict["usage"] is not None:
|
||||
increase_usage(llm_usage, usage_dict["usage"])
|
||||
else:
|
||||
usage_dict["usage"] = LLMUsage.empty_usage()
|
||||
|
||||
@@ -166,9 +175,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
tool_name=scratchpad.action.action_name if scratchpad.action else "",
|
||||
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
|
||||
tool_invoke_meta={},
|
||||
thought=scratchpad.thought,
|
||||
thought=scratchpad.thought or "",
|
||||
observation="",
|
||||
answer=scratchpad.agent_response,
|
||||
answer=scratchpad.agent_response or "",
|
||||
messages_ids=[],
|
||||
llm_usage=usage_dict["usage"],
|
||||
)
|
||||
@@ -209,7 +218,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
agent_thought=agent_thought,
|
||||
tool_name=scratchpad.action.action_name,
|
||||
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
|
||||
thought=scratchpad.thought,
|
||||
thought=scratchpad.thought or "",
|
||||
observation={scratchpad.action.action_name: tool_invoke_response},
|
||||
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
|
||||
answer=scratchpad.agent_response,
|
||||
@@ -247,8 +256,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
answer=final_answer,
|
||||
messages_ids=[],
|
||||
)
|
||||
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
if self.variables_pool is not None and self.db_variables_pool is not None:
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
@@ -307,8 +316,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
|
||||
# publish files
|
||||
for message_file_id, save_as in message_files:
|
||||
if save_as:
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
|
||||
if save_as is not None and self.variables_pool:
|
||||
# FIXME the save_as type is confusing, it should be a string or not
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as))
|
||||
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
@@ -325,7 +335,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
"""
|
||||
return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
|
||||
|
||||
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
|
||||
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
fill in inputs from external data tools
|
||||
"""
|
||||
@@ -376,11 +386,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
"""
|
||||
result: list[PromptMessage] = []
|
||||
scratchpads: list[AgentScratchpadUnit] = []
|
||||
current_scratchpad: AgentScratchpadUnit = None
|
||||
current_scratchpad: AgentScratchpadUnit | None = None
|
||||
|
||||
for message in self.history_prompt_messages:
|
||||
if isinstance(message, AssistantPromptMessage):
|
||||
if not current_scratchpad:
|
||||
if not isinstance(message.content, str | None):
|
||||
raise NotImplementedError("expected str type")
|
||||
current_scratchpad = AgentScratchpadUnit(
|
||||
agent_response=message.content,
|
||||
thought=message.content or "I am thinking about how to help you",
|
||||
@@ -399,8 +411,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
except:
|
||||
pass
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
if current_scratchpad:
|
||||
if not current_scratchpad:
|
||||
continue
|
||||
if isinstance(message.content, str):
|
||||
current_scratchpad.observation = message.content
|
||||
else:
|
||||
raise NotImplementedError("expected str type")
|
||||
elif isinstance(message, UserPromptMessage):
|
||||
if scratchpads:
|
||||
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
|
||||
|
@@ -19,7 +19,12 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
"""
|
||||
Organize system prompt
|
||||
"""
|
||||
if not self.app_config.agent:
|
||||
raise ValueError("Agent configuration is not set")
|
||||
|
||||
prompt_entity = self.app_config.agent.prompt
|
||||
if not prompt_entity:
|
||||
raise ValueError("Agent prompt configuration is not set")
|
||||
first_prompt = prompt_entity.first_prompt
|
||||
|
||||
system_prompt = (
|
||||
@@ -75,6 +80,7 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
assistant_messages = []
|
||||
else:
|
||||
assistant_message = AssistantPromptMessage(content="")
|
||||
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
|
||||
for unit in agent_scratchpad:
|
||||
if unit.is_final():
|
||||
assistant_message.content += f"Final Answer: {unit.agent_response}"
|
||||
|
@@ -2,7 +2,12 @@ import json
|
||||
from typing import Optional
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
|
||||
@@ -11,7 +16,11 @@ class CotCompletionAgentRunner(CotAgentRunner):
|
||||
"""
|
||||
Organize instruction prompt
|
||||
"""
|
||||
if self.app_config.agent is None:
|
||||
raise ValueError("Agent configuration is not set")
|
||||
prompt_entity = self.app_config.agent.prompt
|
||||
if prompt_entity is None:
|
||||
raise ValueError("prompt entity is not set")
|
||||
first_prompt = prompt_entity.first_prompt
|
||||
|
||||
system_prompt = (
|
||||
@@ -33,7 +42,13 @@ class CotCompletionAgentRunner(CotAgentRunner):
|
||||
if isinstance(message, UserPromptMessage):
|
||||
historic_prompt += f"Question: {message.content}\n\n"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
historic_prompt += message.content + "\n\n"
|
||||
if isinstance(message.content, str):
|
||||
historic_prompt += message.content + "\n\n"
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if not isinstance(content, TextPromptMessageContent):
|
||||
continue
|
||||
historic_prompt += content.data
|
||||
|
||||
return historic_prompt
|
||||
|
||||
@@ -50,7 +65,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
|
||||
# organize current assistant messages
|
||||
agent_scratchpad = self._agent_scratchpad
|
||||
assistant_prompt = ""
|
||||
for unit in agent_scratchpad:
|
||||
for unit in agent_scratchpad or []:
|
||||
if unit.is_final():
|
||||
assistant_prompt += f"Final Answer: {unit.agent_response}"
|
||||
else:
|
||||
|
@@ -78,5 +78,5 @@ class AgentEntity(BaseModel):
|
||||
model: str
|
||||
strategy: Strategy
|
||||
prompt: Optional[AgentPromptEntity] = None
|
||||
tools: list[AgentToolEntity] = None
|
||||
tools: list[AgentToolEntity] | None = None
|
||||
max_iteration: int = 5
|
||||
|
@@ -40,6 +40,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
app_generate_entity = self.application_generate_entity
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config is not None, "app_config is required"
|
||||
assert app_config.agent is not None, "app_config.agent is required"
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
@@ -49,7 +51,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
llm_usage = {"usage": None}
|
||||
llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()}
|
||||
final_answer = ""
|
||||
|
||||
# get tracing instance
|
||||
@@ -75,7 +77,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
# the last iteration, remove all tools
|
||||
prompt_messages_tools = []
|
||||
|
||||
message_file_ids = []
|
||||
message_file_ids: list[str] = []
|
||||
agent_thought = self.create_agent_thought(
|
||||
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
|
||||
)
|
||||
@@ -105,7 +107,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
current_llm_usage = None
|
||||
|
||||
if self.stream_tool_call:
|
||||
if self.stream_tool_call and isinstance(chunks, Generator):
|
||||
is_first_chunk = True
|
||||
for chunk in chunks:
|
||||
if is_first_chunk:
|
||||
@@ -116,7 +118,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
# check if there is any tool call
|
||||
if self.check_tool_calls(chunk):
|
||||
function_call_state = True
|
||||
tool_calls.extend(self.extract_tool_calls(chunk))
|
||||
tool_calls.extend(self.extract_tool_calls(chunk) or [])
|
||||
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
|
||||
try:
|
||||
tool_call_inputs = json.dumps(
|
||||
@@ -131,19 +133,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
for content in chunk.delta.message.content:
|
||||
response += content.data
|
||||
else:
|
||||
response += chunk.delta.message.content
|
||||
response += str(chunk.delta.message.content)
|
||||
|
||||
if chunk.delta.usage:
|
||||
increase_usage(llm_usage, chunk.delta.usage)
|
||||
current_llm_usage = chunk.delta.usage
|
||||
|
||||
yield chunk
|
||||
else:
|
||||
result: LLMResult = chunks
|
||||
elif not self.stream_tool_call and isinstance(chunks, LLMResult):
|
||||
result = chunks
|
||||
# check if there is any tool call
|
||||
if self.check_blocking_tool_calls(result):
|
||||
function_call_state = True
|
||||
tool_calls.extend(self.extract_blocking_tool_calls(result))
|
||||
tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
|
||||
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
|
||||
try:
|
||||
tool_call_inputs = json.dumps(
|
||||
@@ -162,7 +164,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
for content in result.message.content:
|
||||
response += content.data
|
||||
else:
|
||||
response += result.message.content
|
||||
response += str(result.message.content)
|
||||
|
||||
if not result.message.content:
|
||||
result.message.content = ""
|
||||
@@ -181,6 +183,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
usage=result.usage,
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"invalid chunks type: {type(chunks)}")
|
||||
|
||||
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
|
||||
if tool_calls:
|
||||
@@ -243,7 +247,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
# publish files
|
||||
for message_file_id, save_as in message_files:
|
||||
if save_as:
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
|
||||
if self.variables_pool:
|
||||
self.variables_pool.set_file(
|
||||
tool_name=tool_call_name, value=message_file_id, name=save_as
|
||||
)
|
||||
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
@@ -263,7 +270,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
if tool_response["tool_response"] is not None:
|
||||
self._current_thoughts.append(
|
||||
ToolPromptMessage(
|
||||
content=tool_response["tool_response"],
|
||||
content=str(tool_response["tool_response"]),
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_call_name,
|
||||
)
|
||||
@@ -273,9 +280,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
# save agent thought
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=None,
|
||||
tool_name="",
|
||||
tool_input="",
|
||||
thought="",
|
||||
tool_invoke_meta={
|
||||
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
|
||||
},
|
||||
@@ -283,7 +290,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
tool_response["tool_call_name"]: tool_response["tool_response"]
|
||||
for tool_response in tool_responses
|
||||
},
|
||||
answer=None,
|
||||
answer="",
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
@@ -296,7 +303,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
if self.variables_pool and self.db_variables_pool:
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
@@ -389,9 +397,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
|
||||
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
|
||||
|
||||
return prompt_messages
|
||||
return prompt_messages or []
|
||||
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
@@ -449,7 +457,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
def _organize_prompt_messages(self):
|
||||
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
|
||||
query_prompt_messages = self._organize_user_query(self.query, [])
|
||||
query_prompt_messages = self._organize_user_query(self.query or "", [])
|
||||
|
||||
self.history_prompt_messages = AgentHistoryPromptTransform(
|
||||
model_config=self.model_config,
|
||||
|
@@ -38,7 +38,7 @@ class CotAgentOutputParser:
|
||||
except:
|
||||
return json_str or ""
|
||||
|
||||
def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
|
||||
def extra_json_from_code_block(code_block) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
|
||||
code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
|
||||
if not code_blocks:
|
||||
return
|
||||
@@ -67,15 +67,15 @@ class CotAgentOutputParser:
|
||||
for response in llm_response:
|
||||
if response.delta.usage:
|
||||
usage_dict["usage"] = response.delta.usage
|
||||
response = response.delta.message.content
|
||||
if not isinstance(response, str):
|
||||
response_content = response.delta.message.content
|
||||
if not isinstance(response_content, str):
|
||||
continue
|
||||
|
||||
# stream
|
||||
index = 0
|
||||
while index < len(response):
|
||||
while index < len(response_content):
|
||||
steps = 1
|
||||
delta = response[index : index + steps]
|
||||
delta = response_content[index : index + steps]
|
||||
yield_delta = False
|
||||
|
||||
if delta == "`":
|
||||
|
Reference in New Issue
Block a user