feat: mypy for all type check (#10921)
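As an aside on how a project-wide check like the one in the title can be run: the sketch below uses mypy's documented Python API (mypy.api.run). The target path and the wrapper function are illustrative assumptions, not the repository's actual tooling.

# A minimal sketch of invoking mypy programmatically; path and flags are assumptions.
from mypy import api

def run_type_check(path: str = "api/") -> int:
    # api.run returns (normal_report, error_report, exit_status)
    stdout, stderr, exit_status = api.run([path])
    if stdout:
        print(stdout, end="")
    if stderr:
        print(stderr, end="")
    return exit_status

if __name__ == "__main__":
    raise SystemExit(run_type_check())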
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, cast
 
 from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
@@ -32,9 +32,12 @@ class BuiltinTool(Tool):
         :return: the model result
         """
         # invoke model
+        if self.runtime is None or self.identity is None:
+            raise ValueError("runtime and identity are required")
+
         return ModelInvocationUtils.invoke(
             user_id=user_id,
-            tenant_id=self.runtime.tenant_id,
+            tenant_id=self.runtime.tenant_id or "",
             tool_type="builtin",
             tool_name=self.identity.name,
             prompt_messages=prompt_messages,
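The hunk above shows the two narrowing patterns this commit leans on: an explicit None check so mypy narrows Optional attributes, and an `or ""` fallback where a plain str is required. A minimal, self-contained sketch of the same idea (the Runtime class and field names below are illustrative, not the project's types):

from typing import Optional

class Runtime:
    # illustrative stand-in for an attribute typed Optional[str]
    tenant_id: Optional[str] = None

def invoke(tenant_id: str) -> None:
    # stand-in for a call that requires a plain str
    print(f"invoking for tenant {tenant_id!r}")

def call(runtime: Optional[Runtime]) -> None:
    if runtime is None:
        # the explicit check narrows Optional[Runtime] to Runtime below
        raise ValueError("runtime is required")
    # tenant_id is still Optional[str]; `or ""` yields a plain str,
    # so mypy stops reporting an arg-type error for this call
    invoke(tenant_id=runtime.tenant_id or "")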
@@ -50,8 +53,11 @@ class BuiltinTool(Tool):
         :param model_config: the model config
         :return: the max tokens
         """
+        if self.runtime is None:
+            raise ValueError("runtime is required")
+
         return ModelInvocationUtils.get_max_llm_context_tokens(
-            tenant_id=self.runtime.tenant_id,
+            tenant_id=self.runtime.tenant_id or "",
         )
 
     def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
@@ -61,7 +67,12 @@ class BuiltinTool(Tool):
         :param prompt_messages: the prompt messages
         :return: the tokens
         """
-        return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages)
+        if self.runtime is None:
+            raise ValueError("runtime is required")
+
+        return ModelInvocationUtils.calculate_tokens(
+            tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
+        )
 
     def summary(self, user_id: str, content: str) -> str:
         max_tokens = self.get_max_tokens()
@@ -81,7 +92,7 @@ class BuiltinTool(Tool):
                 stop=[],
             )
 
-            return summary.message.content
+            return cast(str, summary.message.content)
 
         lines = content.split("\n")
         new_lines = []
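Here the message content field is evidently typed more broadly than str, so the new code wraps it in typing.cast to satisfy the method's declared str return type. cast() has no runtime effect; it only instructs the checker. A standalone sketch of the same technique, with a made-up Message type:

from typing import Union, cast

class Message:
    # made-up stand-in: content may be a plain string or a list of parts
    content: Union[str, list[str]] = ""

def summarize(message: Message) -> str:
    # cast() is a no-op at runtime; it only tells mypy to treat the value as str,
    # which satisfies the declared -> str return type
    return cast(str, message.content)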
@@ -102,16 +113,16 @@ class BuiltinTool(Tool):
 
         # merge lines into messages with max tokens
         messages: list[str] = []
-        for i in new_lines:
+        for j in new_lines:
             if len(messages) == 0:
-                messages.append(i)
+                messages.append(j)
             else:
-                if len(messages[-1]) + len(i) < max_tokens * 0.5:
-                    messages[-1] += i
-                if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
-                    messages.append(i)
+                if len(messages[-1]) + len(j) < max_tokens * 0.5:
+                    messages[-1] += j
+                if get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7:
+                    messages.append(j)
                 else:
-                    messages[-1] += i
+                    messages[-1] += j
 
         summaries = []
         for i in range(len(messages)):
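This last hunk only renames the inner loop variable from i to j. Because i is reused further down as an int index (for i in range(len(messages))), mypy (without --allow-redefinition) pins i to its first inferred type and reports an incompatible redefinition; giving the string loop its own name avoids that. A minimal sketch of the same conflict:

# Without --allow-redefinition, mypy fixes a variable to its first inferred type,
# so reusing the same name for str and int in one scope is rejected.
def merge(lines: list[str]) -> list[str]:
    messages: list[str] = []
    for j in lines:                       # j: str (was `i` before the rename)
        messages.append(j)
    for i in range(len(messages)):        # i: int; a second `i` of type str would conflict
        messages[i] = messages[i].strip()
    return messages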