feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -1,7 +1,6 @@
import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Optional, Union, cast
@@ -53,6 +52,7 @@ logger = logging.getLogger(__name__)
class BaseAgentRunner(AppRunner):
def __init__(
self,
*,
tenant_id: str,
application_generate_entity: AgentChatAppGenerateEntity,
conversation: Conversation,
@@ -66,7 +66,7 @@ class BaseAgentRunner(AppRunner):
prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance | None = None,
model_instance: ModelInstance,
) -> None:
self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity
@@ -117,7 +117,7 @@ class BaseAgentRunner(AppRunner):
features = model_schema.features if model_schema and model_schema.features else []
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
self.query = None
self.query: Optional[str] = ""
self._current_thoughts: list[PromptMessage] = []
def _repack_app_generate_entity(
@@ -145,7 +145,7 @@ class BaseAgentRunner(AppRunner):
message_tool = PromptMessageTool(
name=tool.tool_name,
description=tool_entity.description.llm,
description=tool_entity.description.llm if tool_entity.description else "",
parameters={
"type": "object",
"properties": {},
@@ -167,7 +167,7 @@ class BaseAgentRunner(AppRunner):
continue
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
enum = [option.value for option in parameter.options] if parameter.options else []
message_tool.parameters["properties"][parameter.name] = {
"type": parameter_type,
@@ -187,8 +187,8 @@ class BaseAgentRunner(AppRunner):
convert dataset retriever tool to prompt message tool
"""
prompt_tool = PromptMessageTool(
name=tool.identity.name,
description=tool.description.llm,
name=tool.identity.name if tool.identity else "unknown",
description=tool.description.llm if tool.description else "",
parameters={
"type": "object",
"properties": {},
@@ -210,14 +210,14 @@ class BaseAgentRunner(AppRunner):
return prompt_tool
def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
"""
Init tools
"""
tool_instances = {}
prompt_messages_tools = []
for tool in self.app_config.agent.tools if self.app_config.agent else []:
for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
try:
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
except Exception:
@@ -234,7 +234,8 @@ class BaseAgentRunner(AppRunner):
# save prompt tool
prompt_messages_tools.append(prompt_tool)
# save tool entity
tool_instances[dataset_tool.identity.name] = dataset_tool
if dataset_tool.identity is not None:
tool_instances[dataset_tool.identity.name] = dataset_tool
return tool_instances, prompt_messages_tools
@@ -258,7 +259,7 @@ class BaseAgentRunner(AppRunner):
continue
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
enum = [option.value for option in parameter.options] if parameter.options else []
prompt_tool.parameters["properties"][parameter.name] = {
"type": parameter_type,
@@ -322,16 +323,21 @@ class BaseAgentRunner(AppRunner):
tool_name: str,
tool_input: Union[str, dict],
thought: str,
observation: Union[str, dict],
tool_invoke_meta: Union[str, dict],
observation: Union[str, dict, None],
tool_invoke_meta: Union[str, dict, None],
answer: str,
messages_ids: list[str],
llm_usage: LLMUsage = None,
) -> MessageAgentThought:
llm_usage: LLMUsage | None = None,
):
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
queried_thought = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
)
if not queried_thought:
raise ValueError(f"Agent thought {agent_thought.id} not found")
agent_thought = queried_thought
if thought is not None:
agent_thought.thought = thought
@@ -404,7 +410,7 @@ class BaseAgentRunner(AppRunner):
"""
convert tool variables to db variables
"""
db_variables = (
queried_variables = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
@@ -412,6 +418,11 @@ class BaseAgentRunner(AppRunner):
.first()
)
if not queried_variables:
return
db_variables = queried_variables
db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit()
@@ -421,7 +432,7 @@ class BaseAgentRunner(AppRunner):
"""
Organize agent history
"""
result = []
result: list[PromptMessage] = []
# check if there is a system message in the beginning of the conversation
for prompt_message in prompt_messages:
if isinstance(prompt_message, SystemPromptMessage):