feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -88,8 +88,8 @@ class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM
def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]:
node_inputs = None
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
node_inputs: Optional[dict[str, Any]] = None
process_data = None
try:
@@ -196,7 +196,6 @@ class LLMNode(BaseNode[LLMNodeData]):
error_type=type(e).__name__,
)
)
return
except Exception as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
@@ -206,7 +205,6 @@ class LLMNode(BaseNode[LLMNodeData]):
process_data=process_data,
)
)
return
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
@@ -302,7 +300,7 @@ class LLMNode(BaseNode[LLMNodeData]):
return messages
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
variables = {}
variables: dict[str, Any] = {}
if not node_data.prompt_config:
return variables
@@ -319,7 +317,7 @@ class LLMNode(BaseNode[LLMNodeData]):
"""
# check if it's a context structure
if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
return input_dict["content"]
return str(input_dict["content"])
# else, parse the dict
try:
@@ -557,7 +555,8 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_messages = []
# FIXME: fix the type error cause prompt_messages is type quick a few times
prompt_messages: list[Any] = []
if isinstance(prompt_template, list):
# For chat model
@@ -783,7 +782,7 @@ class LLMNode(BaseNode[LLMNodeData]):
else:
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
variable_mapping = {}
variable_mapping: dict[str, Any] = {}
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
@@ -981,7 +980,7 @@ def _handle_memory_chat_mode(
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
memory_messages = []
memory_messages: Sequence[PromptMessage] = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)