feat: mypy for all type check (#10921)

Author: yihong
Date: 2024-12-24 18:38:51 +08:00
Committed by: GitHub
Parent: c91e8b1737
Commit: 56e15d09a9
584 changed files with 3975 additions and 2826 deletions


@@ -40,6 +40,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
app_generate_entity = self.application_generate_entity
app_config = self.app_config
+assert app_config is not None, "app_config is required"
+assert app_config.agent is not None, "app_config.agent is required"
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
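Note: the two asserts added above are the standard assert-based narrowing that mypy needs before an Optional attribute is dereferenced. A minimal sketch of the pattern, with placeholder classes rather than the real Dify config types:

    from typing import Optional


    class AgentConfig:
        max_iteration: int = 5


    class AppConfig:
        # declared Optional, so mypy flags any unguarded access to .agent
        agent: Optional[AgentConfig] = None


    def run(app_config: Optional[AppConfig]) -> int:
        # each assert narrows the Optional away for the code below it
        assert app_config is not None, "app_config is required"
        assert app_config.agent is not None, "app_config.agent is required"
        return app_config.agent.max_iteration  # mypy now sees AgentConfig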
@@ -49,7 +51,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# continue to run until there is not any tool call
function_call_state = True
-llm_usage = {"usage": None}
+llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()}
final_answer = ""
# get tracing instance
@@ -75,7 +77,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# the last iteration, remove all tools
prompt_messages_tools = []
-message_file_ids = []
+message_file_ids: list[str] = []
agent_thought = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
)
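Note: the two annotations above are needed because mypy cannot infer a usable value type from {"usage": None} or from an empty [] literal; the commit also swaps the None sentinel for LLMUsage.empty_usage() so the dict has a single value type. A small illustration, with Usage standing in for the real LLMUsage:

    from dataclasses import dataclass


    @dataclass
    class Usage:
        # stand-in for the real LLMUsage
        total_tokens: int = 0

        @classmethod
        def empty_usage(cls) -> "Usage":
            return cls()


    # without the annotation mypy infers dict[str, None] and rejects later
    # writes of Usage values, hence the explicit value type and the
    # empty-usage sentinel instead of None
    llm_usage: dict[str, Usage] = {"usage": Usage.empty_usage()}

    # an empty list literal has no inferable element type, so it is spelled out
    message_file_ids: list[str] = []
    message_file_ids.append("file-123")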
@@ -105,7 +107,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
current_llm_usage = None
-if self.stream_tool_call:
+if self.stream_tool_call and isinstance(chunks, Generator):
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
@@ -116,7 +118,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
-tool_calls.extend(self.extract_tool_calls(chunk))
+tool_calls.extend(self.extract_tool_calls(chunk) or [])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
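Note: the "or []" added above guards against extract_tool_calls returning None, which its (presumably Optional) return type allows and which list.extend rejects. Roughly, with a hypothetical helper:

    from typing import Optional

    # (tool_call_id, tool_call_name, tool_call_args): mirrors the runner's tuples
    ToolCall = tuple[str, str, dict]


    def extract_tool_calls(chunk: dict) -> Optional[list[ToolCall]]:
        # hypothetical helper: returns None when the chunk has no tool calls
        calls = chunk.get("tool_calls")
        return list(calls) if calls else None


    tool_calls: list[ToolCall] = []
    chunk = {"tool_calls": [("call-1", "search", {"query": "weather"})]}
    # "or []" collapses the Optional so extend() always receives a list
    tool_calls.extend(extract_tool_calls(chunk) or [])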
@@ -131,19 +133,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
for content in chunk.delta.message.content:
response += content.data
else:
-response += chunk.delta.message.content
+response += str(chunk.delta.message.content)
if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage
yield chunk
-else:
-    result: LLMResult = chunks
+elif not self.stream_tool_call and isinstance(chunks, LLMResult):
+    result = chunks
# check if there is any tool call
if self.check_blocking_tool_calls(result):
function_call_state = True
-tool_calls.extend(self.extract_blocking_tool_calls(result))
+tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
@@ -162,7 +164,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
for content in result.message.content:
response += content.data
else:
-response += result.message.content
+response += str(result.message.content)
if not result.message.content:
result.message.content = ""
@@ -181,6 +183,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
usage=result.usage,
),
)
+else:
+    raise RuntimeError(f"invalid chunks type: {type(chunks)}")
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
if tool_calls:
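Note: the stream_tool_call rewrite in the hunks above is the usual way to consume a call whose return type is a union (a streamed Generator or a blocking LLMResult): each isinstance branch narrows the union for mypy, and the trailing else turns an impossible type into a loud RuntimeError instead of a silent fall-through. A stripped-down sketch with placeholder types:

    from collections.abc import Generator
    from typing import Union


    def invoke_llm(stream: bool) -> Union[Generator[str, None, None], str]:
        # stand-in for a model call that either streams chunks or returns a result
        if stream:
            return (token for token in ("Hel", "lo"))
        return "Hello"


    def collect(chunks: Union[Generator[str, None, None], str], stream: bool) -> str:
        if stream and isinstance(chunks, Generator):
            # narrowed to a generator: safe to iterate
            return "".join(chunks)
        elif not stream and isinstance(chunks, str):
            # narrowed to str: safe to return directly
            return chunks
        else:
            raise RuntimeError(f"invalid chunks type: {type(chunks)}")


    print(collect(invoke_llm(stream=True), stream=True))    # Hello
    print(collect(invoke_llm(stream=False), stream=False))  # Hello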
@@ -243,7 +247,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# publish files
for message_file_id, save_as in message_files:
if save_as:
-self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
+if self.variables_pool:
+    self.variables_pool.set_file(
+        tool_name=tool_call_name, value=message_file_id, name=save_as
+    )
# publish message file
self.queue_manager.publish(
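Note: variables_pool is evidently an Optional attribute on the runner, so the bare set_file call no longer type-checks; the hunk wraps it in a truthiness guard that narrows the Optional. The same idea in miniature, with placeholder classes:

    from typing import Optional


    class VariablesPool:
        def __init__(self) -> None:
            self.files: dict[str, str] = {}

        def set_file(self, tool_name: str, value: str, name: str) -> None:
            self.files[f"{tool_name}/{name}"] = value


    class Runner:
        def __init__(self, variables_pool: Optional[VariablesPool] = None) -> None:
            self.variables_pool = variables_pool

        def publish_file(self, tool_call_name: str, message_file_id: str, save_as: str) -> None:
            # the if-guard narrows Optional[VariablesPool] to VariablesPool
            if self.variables_pool:
                self.variables_pool.set_file(
                    tool_name=tool_call_name, value=message_file_id, name=save_as
                )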
@@ -263,7 +270,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if tool_response["tool_response"] is not None:
self._current_thoughts.append(
ToolPromptMessage(
-content=tool_response["tool_response"],
+content=str(tool_response["tool_response"]),
tool_call_id=tool_call_id,
name=tool_call_name,
)
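Note: the tool_response dict holds values that are not statically known to be strings, while ToolPromptMessage.content expects one, so the value is coerced with str(...). A blunt but type-safe pattern, shown here with a simplified stand-in message class:

    from typing import Any


    class ToolPromptMessage:
        # simplified stand-in for the real prompt-message class
        def __init__(self, content: str, tool_call_id: str, name: str) -> None:
            self.content = content
            self.tool_call_id = tool_call_id
            self.name = name


    tool_response: dict[str, Any] = {
        "tool_call_id": "call-1",
        "tool_call_name": "weather",
        "tool_response": {"temperature": 21},  # not necessarily a str
    }

    # str(...) guarantees the declared str type whatever the dict actually holds
    message = ToolPromptMessage(
        content=str(tool_response["tool_response"]),
        tool_call_id=tool_response["tool_call_id"],
        name=tool_response["tool_call_name"],
    )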
@@ -273,9 +280,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
-tool_name=None,
-tool_input=None,
-thought=None,
+tool_name="",
+tool_input="",
+thought="",
tool_invoke_meta={
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
},
@@ -283,7 +290,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_response["tool_call_name"]: tool_response["tool_response"]
for tool_response in tool_responses
},
-answer=None,
+answer="",
messages_ids=message_file_ids,
)
self.queue_manager.publish(
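Note: save_agent_thought evidently declares these parameters as plain str, so passing None no longer type-checks; empty strings keep the call sites valid without widening the signature to Optional. In miniature, with a stand-in function:

    def save_agent_thought(tool_name: str, tool_input: str, thought: str, answer: str) -> dict[str, str]:
        # parameters are plain str, so callers pass "" rather than None
        return {
            "tool_name": tool_name,
            "tool_input": tool_input,
            "thought": thought,
            "answer": answer,
        }


    # the call sites in these hunks switch from None to "" for exactly this reason
    record = save_agent_thought(tool_name="", tool_input="", thought="", answer="")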
@@ -296,7 +303,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
iteration_step += 1
-self.update_db_variables(self.variables_pool, self.db_variables_pool)
+if self.variables_pool and self.db_variables_pool:
+    self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
@@ -389,9 +397,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
-return prompt_messages
+return prompt_messages or []

-def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
@@ -449,7 +457,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
def _organize_prompt_messages(self):
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
-query_prompt_messages = self._organize_user_query(self.query, [])
+query_prompt_messages = self._organize_user_query(self.query or "", [])
self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,