feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -1,7 +1,7 @@
import json
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Optional, Union
from collections.abc import Generator, Mapping
from typing import Any, Optional
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentScratchpadUnit
@@ -12,6 +12,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
ToolPromptMessage,
UserPromptMessage,
)
@@ -26,18 +27,18 @@ from models.model import Message
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ["wenxin"]
_historic_prompt_messages: list[PromptMessage] = None
_agent_scratchpad: list[AgentScratchpadUnit] = None
_instruction: str = None
_query: str = None
_prompt_messages_tools: list[PromptMessage] = None
_historic_prompt_messages: list[PromptMessage] | None = None
_agent_scratchpad: list[AgentScratchpadUnit] | None = None
_instruction: str = "" # FIXME this must be str for now
_query: str | None = None
_prompt_messages_tools: list[PromptMessageTool] = []
def run(
self,
message: Message,
query: str,
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
inputs: Mapping[str, str],
) -> Generator:
"""
Run Cot agent application
"""
@@ -57,19 +58,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# init instruction
inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs)
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1
# convert tools into ModelRuntime Tool format
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
function_call_state = True
llm_usage = {"usage": None}
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = ""
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
@@ -90,7 +91,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# the last iteration, remove all tools
self._prompt_messages_tools = []
message_file_ids = []
message_file_ids: list[str] = []
agent_thought = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
@@ -105,7 +106,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
chunks = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
tools=[],
@@ -115,11 +116,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
callbacks=[],
)
if not isinstance(chunks, Generator):
raise ValueError("Expected streaming response from LLM")
# check llm result
if not chunks:
raise ValueError("failed to invoke llm")
usage_dict = {}
usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response="",
@@ -139,25 +143,30 @@ class CotAgentRunner(BaseAgentRunner, ABC):
if isinstance(chunk, AgentScratchpadUnit.Action):
action = chunk
# detect action
scratchpad.agent_response += json.dumps(chunk.model_dump())
if scratchpad.agent_response is not None:
scratchpad.agent_response += json.dumps(chunk.model_dump())
scratchpad.action_str = json.dumps(chunk.model_dump())
scratchpad.action = action
else:
scratchpad.agent_response += chunk
scratchpad.thought += chunk
if scratchpad.agent_response is not None:
scratchpad.agent_response += chunk
if scratchpad.thought is not None:
scratchpad.thought += chunk
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint="",
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
)
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad)
if scratchpad.thought is not None:
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
if self._agent_scratchpad is not None:
self._agent_scratchpad.append(scratchpad)
# get llm usage
if "usage" in usage_dict:
increase_usage(llm_usage, usage_dict["usage"])
if usage_dict["usage"] is not None:
increase_usage(llm_usage, usage_dict["usage"])
else:
usage_dict["usage"] = LLMUsage.empty_usage()
@@ -166,9 +175,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tool_name=scratchpad.action.action_name if scratchpad.action else "",
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
tool_invoke_meta={},
thought=scratchpad.thought,
thought=scratchpad.thought or "",
observation="",
answer=scratchpad.agent_response,
answer=scratchpad.agent_response or "",
messages_ids=[],
llm_usage=usage_dict["usage"],
)
@@ -209,7 +218,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
agent_thought=agent_thought,
tool_name=scratchpad.action.action_name,
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
thought=scratchpad.thought,
thought=scratchpad.thought or "",
observation={scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
answer=scratchpad.agent_response,
@@ -247,8 +256,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
answer=final_answer,
messages_ids=[],
)
self.update_db_variables(self.variables_pool, self.db_variables_pool)
if self.variables_pool is not None and self.db_variables_pool is not None:
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
@@ -307,8 +316,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# publish files
for message_file_id, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
if save_as is not None and self.variables_pool:
# FIXME the save_as type is confusing, it should be a string or not
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as))
# publish message file
self.queue_manager.publish(
@@ -325,7 +335,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
"""
return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str:
"""
fill in inputs from external data tools
"""
@@ -376,11 +386,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
"""
result: list[PromptMessage] = []
scratchpads: list[AgentScratchpadUnit] = []
current_scratchpad: AgentScratchpadUnit = None
current_scratchpad: AgentScratchpadUnit | None = None
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
if not current_scratchpad:
if not isinstance(message.content, str | None):
raise NotImplementedError("expected str type")
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
@@ -399,8 +411,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
except:
pass
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
if not current_scratchpad:
continue
if isinstance(message.content, str):
current_scratchpad.observation = message.content
else:
raise NotImplementedError("expected str type")
elif isinstance(message, UserPromptMessage):
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))