feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Optional
from collections.abc import Mapping, Sequence
from typing import Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import file_manager
@@ -39,7 +39,7 @@ class AdvancedPromptTransform(PromptTransform):
self,
*,
prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate,
inputs: dict[str, str],
inputs: Mapping[str, str],
query: str,
files: Sequence[File],
context: Optional[str],
@@ -77,7 +77,7 @@ class AdvancedPromptTransform(PromptTransform):
def _get_completion_model_prompt_messages(
self,
prompt_template: CompletionModelPromptTemplate,
inputs: dict,
inputs: Mapping[str, str],
query: Optional[str],
files: Sequence[File],
context: Optional[str],
@@ -90,15 +90,15 @@ class AdvancedPromptTransform(PromptTransform):
"""
raw_prompt = prompt_template.text
prompt_messages = []
prompt_messages: list[PromptMessage] = []
if prompt_template.edition_type == "basic" or not prompt_template.edition_type:
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs}
prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
if memory and memory_config:
if memory and memory_config and memory_config.role_prefix:
role_prefix = memory_config.role_prefix
prompt_inputs = self._set_histories_variable(
memory=memory,
@@ -135,7 +135,7 @@ class AdvancedPromptTransform(PromptTransform):
def _get_chat_model_prompt_messages(
self,
prompt_template: list[ChatModelMessage],
inputs: dict,
inputs: Mapping[str, str],
query: Optional[str],
files: Sequence[File],
context: Optional[str],
@@ -146,7 +146,7 @@ class AdvancedPromptTransform(PromptTransform):
"""
Get chat model prompt messages.
"""
prompt_messages = []
prompt_messages: list[PromptMessage] = []
for prompt_item in prompt_template:
raw_prompt = prompt_item.text
@@ -160,7 +160,7 @@ class AdvancedPromptTransform(PromptTransform):
prompt = vp.convert_template(raw_prompt).text
else:
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs}
prompt_inputs = self._set_context_variable(
context=context, parser=parser, prompt_inputs=prompt_inputs
)
@@ -207,7 +207,7 @@ class AdvancedPromptTransform(PromptTransform):
last_message = prompt_messages[-1] if prompt_messages else None
if last_message and last_message.role == PromptMessageRole.USER:
# get last user message content and add files
prompt_message_contents = [TextPromptMessageContent(data=last_message.content)]
prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))]
for file in files:
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
@@ -229,7 +229,10 @@ class AdvancedPromptTransform(PromptTransform):
return prompt_messages
def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
def _set_context_variable(
self, context: str | None, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str]
) -> Mapping[str, str]:
prompt_inputs = dict(prompt_inputs)
if "#context#" in parser.variable_keys:
if context:
prompt_inputs["#context#"] = context
@@ -238,7 +241,10 @@ class AdvancedPromptTransform(PromptTransform):
return prompt_inputs
def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
def _set_query_variable(
self, query: str, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str]
) -> Mapping[str, str]:
prompt_inputs = dict(prompt_inputs)
if "#query#" in parser.variable_keys:
if query:
prompt_inputs["#query#"] = query
@@ -254,9 +260,10 @@ class AdvancedPromptTransform(PromptTransform):
raw_prompt: str,
role_prefix: MemoryConfig.RolePrefix,
parser: PromptTemplateParser,
prompt_inputs: dict,
prompt_inputs: Mapping[str, str],
model_config: ModelConfigWithCredentialsEntity,
) -> dict:
) -> Mapping[str, str]:
prompt_inputs = dict(prompt_inputs)
if "#histories#" in parser.variable_keys:
if memory:
inputs = {"#histories#": "", **prompt_inputs}

View File

@@ -31,7 +31,7 @@ class AgentHistoryPromptTransform(PromptTransform):
self.memory = memory
def get_prompt(self) -> list[PromptMessage]:
prompt_messages = []
prompt_messages: list[PromptMessage] = []
num_system = 0
for prompt_message in self.history_messages:
if isinstance(prompt_message, SystemPromptMessage):

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Optional
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -42,7 +42,7 @@ class PromptTransform:
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
@@ -59,7 +59,7 @@ class PromptTransform:
ai_prefix: Optional[str] = None,
) -> str:
"""Get memory messages."""
kwargs = {"max_token_limit": max_token_limit}
kwargs: dict[str, Any] = {"max_token_limit": max_token_limit}
if human_prefix:
kwargs["human_prefix"] = human_prefix
@@ -76,11 +76,15 @@ class PromptTransform:
self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
) -> list[PromptMessage]:
"""Get memory messages."""
return memory.get_history_prompt_messages(
max_token_limit=max_token_limit,
message_limit=memory_config.window.size
if (
memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0
return list(
memory.get_history_prompt_messages(
max_token_limit=max_token_limit,
message_limit=memory_config.window.size
if (
memory_config.window.enabled
and memory_config.window.size is not None
and memory_config.window.size > 0
)
else None,
)
else None,
)

View File

@@ -1,7 +1,8 @@
import enum
import json
import os
from typing import TYPE_CHECKING, Optional
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -41,7 +42,7 @@ class ModelMode(enum.StrEnum):
raise ValueError(f"invalid mode value {value}")
prompt_file_contents = {}
prompt_file_contents: dict[str, Any] = {}
class SimplePromptTransform(PromptTransform):
@@ -53,9 +54,9 @@ class SimplePromptTransform(PromptTransform):
self,
app_mode: AppMode,
prompt_template_entity: PromptTemplateEntity,
inputs: dict,
inputs: Mapping[str, str],
query: str,
files: list["File"],
files: Sequence["File"],
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
@@ -66,7 +67,7 @@ class SimplePromptTransform(PromptTransform):
if model_mode == ModelMode.CHAT:
prompt_messages, stops = self._get_chat_model_prompt_messages(
app_mode=app_mode,
pre_prompt=prompt_template_entity.simple_prompt_template,
pre_prompt=prompt_template_entity.simple_prompt_template or "",
inputs=inputs,
query=query,
files=files,
@@ -77,7 +78,7 @@ class SimplePromptTransform(PromptTransform):
else:
prompt_messages, stops = self._get_completion_model_prompt_messages(
app_mode=app_mode,
pre_prompt=prompt_template_entity.simple_prompt_template,
pre_prompt=prompt_template_entity.simple_prompt_template or "",
inputs=inputs,
query=query,
files=files,
@@ -171,11 +172,11 @@ class SimplePromptTransform(PromptTransform):
inputs: dict,
query: str,
context: Optional[str],
files: list["File"],
files: Sequence["File"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
prompt_messages = []
prompt_messages: list[PromptMessage] = []
# get prompt
prompt, _ = self.get_prompt_str_and_rules(
@@ -216,7 +217,7 @@ class SimplePromptTransform(PromptTransform):
inputs: dict,
query: str,
context: Optional[str],
files: list["File"],
files: Sequence["File"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
@@ -263,7 +264,7 @@ class SimplePromptTransform(PromptTransform):
return [self.get_last_user_message(prompt, files)], stops
def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage:
def get_last_user_message(self, prompt: str, files: Sequence["File"]) -> UserPromptMessage:
if files:
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
@@ -288,7 +289,7 @@ class SimplePromptTransform(PromptTransform):
# Check if the prompt file is already loaded
if prompt_file_name in prompt_file_contents:
return prompt_file_contents[prompt_file_name]
return cast(dict, prompt_file_contents[prompt_file_name])
# Get the absolute path of the subdirectory
prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates")
@@ -301,7 +302,7 @@ class SimplePromptTransform(PromptTransform):
# Store the content of the prompt file
prompt_file_contents[prompt_file_name] = content
return content
return cast(dict, content)
def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
# baichuan

View File

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import cast
from typing import Any, cast
from core.model_runtime.entities import (
AssistantPromptMessage,
@@ -72,7 +72,7 @@ class PromptMessageUtil:
}
)
else:
text = prompt_message.content
text = cast(str, prompt_message.content)
prompt = {"role": role, "text": text, "files": files}
@@ -99,9 +99,9 @@ class PromptMessageUtil:
}
)
else:
text = prompt_message.content
text = cast(str, prompt_message.content)
params = {
params: dict[str, Any] = {
"role": "user",
"text": text,
}

View File

@@ -1,4 +1,5 @@
import re
from collections.abc import Mapping
REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#histories#|#query#|#context#)\}\}")
WITH_VARIABLE_TMPL_REGEX = re.compile(
@@ -28,7 +29,7 @@ class PromptTemplateParser:
# Regular expression to match the template rules
return re.findall(self.regex, self.template)
def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
def format(self, inputs: Mapping[str, str], remove_template_variables: bool = True) -> str:
def replacer(match):
key = match.group(1)
value = inputs.get(key, match.group(0)) # return original matched string if key not found