feat: mypy for all type check (#10921)
@@ -1,5 +1,5 @@
-from collections.abc import Sequence
-from typing import Optional
+from collections.abc import Mapping, Sequence
+from typing import Optional, cast
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.file import file_manager
@@ -39,7 +39,7 @@ class AdvancedPromptTransform(PromptTransform):
         self,
         *,
         prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate,
-        inputs: dict[str, str],
+        inputs: Mapping[str, str],
         query: str,
         files: Sequence[File],
         context: Optional[str],
@@ -77,7 +77,7 @@ class AdvancedPromptTransform(PromptTransform):
     def _get_completion_model_prompt_messages(
         self,
         prompt_template: CompletionModelPromptTemplate,
-        inputs: dict,
+        inputs: Mapping[str, str],
         query: Optional[str],
         files: Sequence[File],
         context: Optional[str],
@@ -90,15 +90,15 @@ class AdvancedPromptTransform(PromptTransform):
         """
         raw_prompt = prompt_template.text
 
-        prompt_messages = []
+        prompt_messages: list[PromptMessage] = []
 
         if prompt_template.edition_type == "basic" or not prompt_template.edition_type:
             parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-            prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+            prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs}
 
             prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
 
-            if memory and memory_config:
+            if memory and memory_config and memory_config.role_prefix:
                 role_prefix = memory_config.role_prefix
                 prompt_inputs = self._set_histories_variable(
                     memory=memory,
@@ -135,7 +135,7 @@ class AdvancedPromptTransform(PromptTransform):
     def _get_chat_model_prompt_messages(
         self,
         prompt_template: list[ChatModelMessage],
-        inputs: dict,
+        inputs: Mapping[str, str],
         query: Optional[str],
         files: Sequence[File],
         context: Optional[str],
@@ -146,7 +146,7 @@ class AdvancedPromptTransform(PromptTransform):
         """
         Get chat model prompt messages.
         """
-        prompt_messages = []
+        prompt_messages: list[PromptMessage] = []
         for prompt_item in prompt_template:
             raw_prompt = prompt_item.text
 
@@ -160,7 +160,7 @@ class AdvancedPromptTransform(PromptTransform):
                 prompt = vp.convert_template(raw_prompt).text
             else:
                 parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-                prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+                prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs}
                 prompt_inputs = self._set_context_variable(
                     context=context, parser=parser, prompt_inputs=prompt_inputs
                 )
@@ -207,7 +207,7 @@ class AdvancedPromptTransform(PromptTransform):
         last_message = prompt_messages[-1] if prompt_messages else None
         if last_message and last_message.role == PromptMessageRole.USER:
             # get last user message content and add files
-            prompt_message_contents = [TextPromptMessageContent(data=last_message.content)]
+            prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))]
             for file in files:
                 prompt_message_contents.append(file_manager.to_prompt_message_content(file))
 
@@ -229,7 +229,10 @@ class AdvancedPromptTransform(PromptTransform):
 
         return prompt_messages
 
-    def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
+    def _set_context_variable(
+        self, context: str | None, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str]
+    ) -> Mapping[str, str]:
+        prompt_inputs = dict(prompt_inputs)
         if "#context#" in parser.variable_keys:
             if context:
                 prompt_inputs["#context#"] = context
@@ -238,7 +241,10 @@ class AdvancedPromptTransform(PromptTransform):
 
         return prompt_inputs
 
-    def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
+    def _set_query_variable(
+        self, query: str, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str]
+    ) -> Mapping[str, str]:
+        prompt_inputs = dict(prompt_inputs)
         if "#query#" in parser.variable_keys:
             if query:
                 prompt_inputs["#query#"] = query
@@ -254,9 +260,10 @@ class AdvancedPromptTransform(PromptTransform):
         raw_prompt: str,
         role_prefix: MemoryConfig.RolePrefix,
         parser: PromptTemplateParser,
-        prompt_inputs: dict,
+        prompt_inputs: Mapping[str, str],
         model_config: ModelConfigWithCredentialsEntity,
-    ) -> dict:
+    ) -> Mapping[str, str]:
+        prompt_inputs = dict(prompt_inputs)
         if "#histories#" in parser.variable_keys:
             if memory:
                 inputs = {"#histories#": "", **prompt_inputs}
@@ -31,7 +31,7 @@ class AgentHistoryPromptTransform(PromptTransform):
         self.memory = memory
 
     def get_prompt(self) -> list[PromptMessage]:
-        prompt_messages = []
+        prompt_messages: list[PromptMessage] = []
         num_system = 0
         for prompt_message in self.history_messages:
             if isinstance(prompt_message, SystemPromptMessage):
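The one change in `AgentHistoryPromptTransform.get_prompt` is the annotation on the empty list. Without it, mypy infers the element type from the first `append`, which can come out narrower than intended; a self-contained sketch with stand-in classes (not the real `PromptMessage` hierarchy):

```python
class Message: ...  # stand-ins for the PromptMessage class tree


class SystemMessage(Message): ...


class UserMessage(Message): ...


def collect_unannotated() -> list[Message]:
    messages = []  # mypy infers list[SystemMessage] from the first append...
    messages.append(SystemMessage())
    messages.append(UserMessage())  # ...and rejects this sibling type
    return messages  # list[SystemMessage] is also not list[Message]: lists are invariant


def collect_annotated() -> list[Message]:
    messages: list[Message] = []  # element type pinned up front
    messages.append(SystemMessage())
    messages.append(UserMessage())  # fine: both are Message subclasses
    return messages
```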
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Optional
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -42,7 +42,7 @@ class PromptTransform:
         ):
             max_tokens = (
                 model_config.parameters.get(parameter_rule.name)
-                or model_config.parameters.get(parameter_rule.use_template)
+                or model_config.parameters.get(parameter_rule.use_template or "")
             ) or 0
 
         rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
@@ -59,7 +59,7 @@ class PromptTransform:
         ai_prefix: Optional[str] = None,
     ) -> str:
         """Get memory messages."""
-        kwargs = {"max_token_limit": max_token_limit}
+        kwargs: dict[str, Any] = {"max_token_limit": max_token_limit}
 
         if human_prefix:
             kwargs["human_prefix"] = human_prefix
@@ -76,11 +76,15 @@ class PromptTransform:
         self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
     ) -> list[PromptMessage]:
         """Get memory messages."""
-        return memory.get_history_prompt_messages(
-            max_token_limit=max_token_limit,
-            message_limit=memory_config.window.size
-            if (
-                memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0
-            )
-            else None,
+        return list(
+            memory.get_history_prompt_messages(
+                max_token_limit=max_token_limit,
+                message_limit=memory_config.window.size
+                if (
+                    memory_config.window.enabled
+                    and memory_config.window.size is not None
+                    and memory_config.window.size > 0
+                )
+                else None,
+            )
         )
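Three separate fixes in `PromptTransform`: `kwargs` gets an explicit `dict[str, Any]` so the later string-valued inserts type-check; the `Sequence` returned by `memory.get_history_prompt_messages` is materialized with `list(...)` to satisfy the declared `list[PromptMessage]` return type; and `.get(parameter_rule.use_template or "")` avoids passing a possibly-`None` key. Assuming `use_template` is `Optional[str]` (the entity definition is not shown in this diff, but the fix implies it), the key trick in isolation:

```python
from typing import Optional


def read_max_tokens(parameters: dict[str, int], name: str, use_template: Optional[str]) -> int:
    # parameters.get(use_template) fails type checking: the key may be None,
    # while the dict's keys are str. `use_template or ""` substitutes a key
    # that is simply absent, so the expression falls through to 0 as before.
    return parameters.get(name) or parameters.get(use_template or "") or 0


print(read_max_tokens({"max_tokens": 1024}, "max_tokens", None))  # 1024
print(read_max_tokens({}, "max_tokens", None))  # 0
```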
@@ -1,7 +1,8 @@
 import enum
 import json
 import os
-from typing import TYPE_CHECKING, Optional
+from collections.abc import Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Optional, cast
 
 from core.app.app_config.entities import PromptTemplateEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -41,7 +42,7 @@ class ModelMode(enum.StrEnum):
         raise ValueError(f"invalid mode value {value}")
 
 
-prompt_file_contents = {}
+prompt_file_contents: dict[str, Any] = {}
 
 
 class SimplePromptTransform(PromptTransform):
@@ -53,9 +54,9 @@ class SimplePromptTransform(PromptTransform):
         self,
         app_mode: AppMode,
         prompt_template_entity: PromptTemplateEntity,
-        inputs: dict,
+        inputs: Mapping[str, str],
         query: str,
-        files: list["File"],
+        files: Sequence["File"],
         context: Optional[str],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
@@ -66,7 +67,7 @@ class SimplePromptTransform(PromptTransform):
         if model_mode == ModelMode.CHAT:
             prompt_messages, stops = self._get_chat_model_prompt_messages(
                 app_mode=app_mode,
-                pre_prompt=prompt_template_entity.simple_prompt_template,
+                pre_prompt=prompt_template_entity.simple_prompt_template or "",
                 inputs=inputs,
                 query=query,
                 files=files,
@@ -77,7 +78,7 @@ class SimplePromptTransform(PromptTransform):
         else:
             prompt_messages, stops = self._get_completion_model_prompt_messages(
                 app_mode=app_mode,
-                pre_prompt=prompt_template_entity.simple_prompt_template,
+                pre_prompt=prompt_template_entity.simple_prompt_template or "",
                 inputs=inputs,
                 query=query,
                 files=files,
@@ -171,11 +172,11 @@ class SimplePromptTransform(PromptTransform):
         inputs: dict,
         query: str,
         context: Optional[str],
-        files: list["File"],
+        files: Sequence["File"],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
-        prompt_messages = []
+        prompt_messages: list[PromptMessage] = []
 
         # get prompt
         prompt, _ = self.get_prompt_str_and_rules(
@@ -216,7 +217,7 @@ class SimplePromptTransform(PromptTransform):
         inputs: dict,
         query: str,
         context: Optional[str],
-        files: list["File"],
+        files: Sequence["File"],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
@@ -263,7 +264,7 @@ class SimplePromptTransform(PromptTransform):
 
         return [self.get_last_user_message(prompt, files)], stops
 
-    def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage:
+    def get_last_user_message(self, prompt: str, files: Sequence["File"]) -> UserPromptMessage:
         if files:
             prompt_message_contents: list[PromptMessageContent] = []
             prompt_message_contents.append(TextPromptMessageContent(data=prompt))
@@ -288,7 +289,7 @@ class SimplePromptTransform(PromptTransform):
 
         # Check if the prompt file is already loaded
        if prompt_file_name in prompt_file_contents:
-            return prompt_file_contents[prompt_file_name]
+            return cast(dict, prompt_file_contents[prompt_file_name])
 
         # Get the absolute path of the subdirectory
         prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates")
@@ -301,7 +302,7 @@ class SimplePromptTransform(PromptTransform):
         # Store the content of the prompt file
         prompt_file_contents[prompt_file_name] = content
 
-        return content
+        return cast(dict, content)
 
     def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
         # baichuan
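In `SimplePromptTransform`, the module-level template cache is now annotated `dict[str, Any]`, so lookups come back as `Any` and the two `cast(dict, ...)` calls restore a concrete return type. A stripped-down sketch of the same caching pattern (names mirror the diff, but this is an illustration, not the module itself):

```python
import json
from typing import Any, cast

# Values are parsed JSON, so Any is the honest value type.
prompt_file_contents: dict[str, Any] = {}


def get_prompt_rules(prompt_file_name: str, raw_json: str) -> dict:
    if prompt_file_name in prompt_file_contents:
        # cast() is a no-op at runtime; it only tells mypy the Any entry is a dict.
        return cast(dict, prompt_file_contents[prompt_file_name])

    content = json.loads(raw_json)
    prompt_file_contents[prompt_file_name] = content  # populate the cache
    return cast(dict, content)
```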
@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from typing import cast
+from typing import Any, cast
 
 from core.model_runtime.entities import (
     AssistantPromptMessage,
@@ -72,7 +72,7 @@ class PromptMessageUtil:
                     }
                 )
             else:
-                text = prompt_message.content
+                text = cast(str, prompt_message.content)
 
             prompt = {"role": role, "text": text, "files": files}
 
@@ -99,9 +99,9 @@ class PromptMessageUtil:
                     }
                 )
             else:
-                text = prompt_message.content
+                text = cast(str, prompt_message.content)
 
-            params = {
+            params: dict[str, Any] = {
                "role": "user",
                "text": text,
            }
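Both `PromptMessageUtil` hunks replace a bare `text = prompt_message.content` with a `cast(str, ...)`: in these `else` branches the content is known to be plain text, but `content` is presumably declared as a wider union, so the code asserts the narrower type to the checker. Worth keeping in mind what `cast` does and does not do:

```python
from typing import cast


def shout(value: object) -> str:
    # cast() performs no runtime check or conversion; it only overrides the
    # checker's view of the type. Prefer isinstance() narrowing when a real
    # runtime check is affordable.
    text = cast(str, value)
    return text.upper()


print(shout("hello"))  # HELLO
# shout(42) would type-check at the call site but crash at runtime with
# AttributeError, because cast() validated nothing.
```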
@@ -1,4 +1,5 @@
 import re
+from collections.abc import Mapping
 
 REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#histories#|#query#|#context#)\}\}")
 WITH_VARIABLE_TMPL_REGEX = re.compile(
@@ -28,7 +29,7 @@ class PromptTemplateParser:
         # Regular expression to match the template rules
         return re.findall(self.regex, self.template)
 
-    def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
+    def format(self, inputs: Mapping[str, str], remove_template_variables: bool = True) -> str:
         def replacer(match):
             key = match.group(1)
             value = inputs.get(key, match.group(0))  # return original matched string if key not found
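Finally, `PromptTemplateParser.format` now accepts any `Mapping[str, str]`, which is all it needs since it only reads via `inputs.get`. A self-contained sketch of the same signature shape, using a simplified version of the `REGEX` above (plain identifiers only, without the `#histories#`-style specials):

```python
import re
from collections.abc import Mapping

TEMPLATE_REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29})\}\}")


def format_template(template: str, inputs: Mapping[str, str]) -> str:
    def replacer(match: re.Match[str]) -> str:
        key = match.group(1)
        # .get is part of the Mapping protocol, so any read-only mapping works;
        # keep the original token when the key is missing.
        return inputs.get(key, match.group(0))

    return TEMPLATE_REGEX.sub(replacer, template)


print(format_template("Hello, {{name}}!", {"name": "Ada"}))  # Hello, Ada!
```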