feat: mypy for all type check (#10921)
This commit is contained in:
@@ -40,6 +40,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
app_generate_entity = self.application_generate_entity
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config is not None, "app_config is required"
|
||||
assert app_config.agent is not None, "app_config.agent is required"
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
@@ -49,7 +51,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
llm_usage = {"usage": None}
|
||||
llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()}
|
||||
final_answer = ""
|
||||
|
||||
# get tracing instance
|
||||
@@ -75,7 +77,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
# the last iteration, remove all tools
|
||||
prompt_messages_tools = []
|
||||
|
||||
message_file_ids = []
|
||||
message_file_ids: list[str] = []
|
||||
agent_thought = self.create_agent_thought(
|
||||
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
|
||||
)
|
||||
@@ -105,7 +107,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
current_llm_usage = None
|
||||
|
||||
if self.stream_tool_call:
|
||||
if self.stream_tool_call and isinstance(chunks, Generator):
|
||||
is_first_chunk = True
|
||||
for chunk in chunks:
|
||||
if is_first_chunk:
|
||||
@@ -116,7 +118,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
# check if there is any tool call
|
||||
if self.check_tool_calls(chunk):
|
||||
function_call_state = True
|
||||
tool_calls.extend(self.extract_tool_calls(chunk))
|
||||
tool_calls.extend(self.extract_tool_calls(chunk) or [])
|
||||
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
|
||||
try:
|
||||
tool_call_inputs = json.dumps(
|
||||
@@ -131,19 +133,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
for content in chunk.delta.message.content:
|
||||
response += content.data
|
||||
else:
|
||||
response += chunk.delta.message.content
|
||||
response += str(chunk.delta.message.content)
|
||||
|
||||
if chunk.delta.usage:
|
||||
increase_usage(llm_usage, chunk.delta.usage)
|
||||
current_llm_usage = chunk.delta.usage
|
||||
|
||||
yield chunk
|
||||
else:
|
||||
result: LLMResult = chunks
|
||||
elif not self.stream_tool_call and isinstance(chunks, LLMResult):
|
||||
result = chunks
|
||||
# check if there is any tool call
|
||||
if self.check_blocking_tool_calls(result):
|
||||
function_call_state = True
|
||||
tool_calls.extend(self.extract_blocking_tool_calls(result))
|
||||
tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
|
||||
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
|
||||
try:
|
||||
tool_call_inputs = json.dumps(
|
||||
@@ -162,7 +164,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
for content in result.message.content:
|
||||
response += content.data
|
||||
else:
|
||||
response += result.message.content
|
||||
response += str(result.message.content)
|
||||
|
||||
if not result.message.content:
|
||||
result.message.content = ""
|
||||
@@ -181,6 +183,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
usage=result.usage,
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"invalid chunks type: {type(chunks)}")
|
||||
|
||||
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
|
||||
if tool_calls:
|
||||
@@ -243,7 +247,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
# publish files
|
||||
for message_file_id, save_as in message_files:
|
||||
if save_as:
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
|
||||
if self.variables_pool:
|
||||
self.variables_pool.set_file(
|
||||
tool_name=tool_call_name, value=message_file_id, name=save_as
|
||||
)
|
||||
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
@@ -263,7 +270,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
if tool_response["tool_response"] is not None:
|
||||
self._current_thoughts.append(
|
||||
ToolPromptMessage(
|
||||
content=tool_response["tool_response"],
|
||||
content=str(tool_response["tool_response"]),
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_call_name,
|
||||
)
|
||||
@@ -273,9 +280,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
# save agent thought
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=None,
|
||||
tool_name="",
|
||||
tool_input="",
|
||||
thought="",
|
||||
tool_invoke_meta={
|
||||
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
|
||||
},
|
||||
@@ -283,7 +290,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
tool_response["tool_call_name"]: tool_response["tool_response"]
|
||||
for tool_response in tool_responses
|
||||
},
|
||||
answer=None,
|
||||
answer="",
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
@@ -296,7 +303,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
if self.variables_pool and self.db_variables_pool:
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
@@ -389,9 +397,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
|
||||
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
|
||||
|
||||
return prompt_messages
|
||||
return prompt_messages or []
|
||||
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
@@ -449,7 +457,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
def _organize_prompt_messages(self):
|
||||
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
|
||||
query_prompt_messages = self._organize_user_query(self.query, [])
|
||||
query_prompt_messages = self._organize_user_query(self.query or "", [])
|
||||
|
||||
self.history_prompt_messages = AgentHistoryPromptTransform(
|
||||
model_config=self.model_config,
|
||||
|
Reference in New Issue
Block a user