feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, cast
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
@@ -32,9 +32,12 @@ class BuiltinTool(Tool):
:return: the model result
"""
# invoke model
if self.runtime is None or self.identity is None:
raise ValueError("runtime and identity are required")
return ModelInvocationUtils.invoke(
user_id=user_id,
tenant_id=self.runtime.tenant_id,
tenant_id=self.runtime.tenant_id or "",
tool_type="builtin",
tool_name=self.identity.name,
prompt_messages=prompt_messages,
@@ -50,8 +53,11 @@ class BuiltinTool(Tool):
:param model_config: the model config
:return: the max tokens
"""
if self.runtime is None:
raise ValueError("runtime is required")
return ModelInvocationUtils.get_max_llm_context_tokens(
tenant_id=self.runtime.tenant_id,
tenant_id=self.runtime.tenant_id or "",
)
def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
@@ -61,7 +67,12 @@ class BuiltinTool(Tool):
:param prompt_messages: the prompt messages
:return: the tokens
"""
return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages)
if self.runtime is None:
raise ValueError("runtime is required")
return ModelInvocationUtils.calculate_tokens(
tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
)
def summary(self, user_id: str, content: str) -> str:
max_tokens = self.get_max_tokens()
@@ -81,7 +92,7 @@ class BuiltinTool(Tool):
stop=[],
)
return summary.message.content
return cast(str, summary.message.content)
lines = content.split("\n")
new_lines = []
@@ -102,16 +113,16 @@ class BuiltinTool(Tool):
# merge lines into messages with max tokens
messages: list[str] = []
for i in new_lines:
for j in new_lines:
if len(messages) == 0:
messages.append(i)
messages.append(j)
else:
if len(messages[-1]) + len(i) < max_tokens * 0.5:
messages[-1] += i
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
messages.append(i)
if len(messages[-1]) + len(j) < max_tokens * 0.5:
messages[-1] += j
if get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7:
messages.append(j)
else:
messages[-1] += i
messages[-1] += j
summaries = []
for i in range(len(messages)):