feat: mypy for all type check (#10921)
This commit is contained in:
@@ -63,7 +63,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
self._remove_unreachable_nodes(event)
|
||||
|
||||
# generate stream outputs
|
||||
yield from self._generate_stream_outputs_when_node_finished(event)
|
||||
yield from self._generate_stream_outputs_when_node_finished(cast(NodeRunSucceededEvent, event))
|
||||
else:
|
||||
yield event
|
||||
|
||||
@@ -130,7 +130,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
chunk_content=text,
|
||||
from_variable_selector=value_selector,
|
||||
from_variable_selector=list(value_selector),
|
||||
route_node_state=event.route_node_state,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
|
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -19,7 +19,7 @@ class StreamProcessor(ABC):
|
||||
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
|
||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
|
||||
finished_node_id = event.route_node_state.node_id
|
||||
if finished_node_id not in self.rest_node_ids:
|
||||
return
|
||||
@@ -32,8 +32,8 @@ class StreamProcessor(ABC):
|
||||
return
|
||||
|
||||
if run_result.edge_source_handle:
|
||||
reachable_node_ids = []
|
||||
unreachable_first_node_ids = []
|
||||
reachable_node_ids: list[str] = []
|
||||
unreachable_first_node_ids: list[str] = []
|
||||
if finished_node_id not in self.graph.edge_mapping:
|
||||
logger.warning(f"node {finished_node_id} has no edge mapping")
|
||||
return
|
||||
|
@@ -38,7 +38,8 @@ class DefaultValue(BaseModel):
|
||||
@staticmethod
|
||||
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
|
||||
"""Unified array type validation"""
|
||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
|
||||
# FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it
|
||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def _convert_number(value: str) -> float:
|
||||
@@ -84,7 +85,7 @@ class DefaultValue(BaseModel):
|
||||
},
|
||||
}
|
||||
|
||||
validator = type_validators.get(self.type)
|
||||
validator: dict[str, Any] = type_validators.get(self.type, {})
|
||||
if not validator:
|
||||
if self.type == DefaultValueType.ARRAY_FILES:
|
||||
# Handle files type
|
||||
|
@@ -125,7 +125,7 @@ class CodeNode(BaseNode[CodeNodeData]):
|
||||
if depth > dify_config.CODE_MAX_DEPTH:
|
||||
raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
|
||||
transformed_result = {}
|
||||
transformed_result: dict[str, Any] = {}
|
||||
if output_schema is None:
|
||||
# validate output thought instance type
|
||||
for output_name, output_value in result.items():
|
||||
|
@@ -14,7 +14,7 @@ class CodeNodeData(BaseNodeData):
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
|
||||
children: Optional[dict[str, "Output"]] = None
|
||||
children: Optional[dict[str, "CodeNodeData.Output"]] = None
|
||||
|
||||
class Dependency(BaseModel):
|
||||
name: str
|
||||
|
@@ -4,6 +4,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import cast
|
||||
|
||||
import docx
|
||||
import pandas as pd
|
||||
@@ -159,7 +160,7 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
|
||||
"""Extract the content from yaml file"""
|
||||
try:
|
||||
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", "ignore"))
|
||||
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
|
||||
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
|
||||
except (UnicodeDecodeError, yaml.YAMLError) as e:
|
||||
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
|
||||
|
||||
@@ -229,9 +230,9 @@ def _download_file_content(file: File) -> bytes:
|
||||
raise FileDownloadError("Missing URL for remote file")
|
||||
response = ssrf_proxy.get(file.remote_url)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
return cast(bytes, response.content)
|
||||
else:
|
||||
return file_manager.download(file)
|
||||
return cast(bytes, file_manager.download(file))
|
||||
except Exception as e:
|
||||
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
|
||||
|
||||
|
@@ -67,7 +67,7 @@ class EndStreamGeneratorRouter:
|
||||
and node_type == NodeType.LLM.value
|
||||
and variable_selector.value_selector[1] == "text"
|
||||
):
|
||||
value_selectors.append(variable_selector.value_selector)
|
||||
value_selectors.append(list(variable_selector.value_selector))
|
||||
|
||||
return value_selectors
|
||||
|
||||
@@ -119,8 +119,7 @@ class EndStreamGeneratorRouter:
|
||||
current_node_id: str,
|
||||
end_node_id: str,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]],
|
||||
# type: ignore[name-defined]
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
end_dependencies: dict[str, list[str]],
|
||||
) -> None:
|
||||
"""
|
||||
|
@@ -23,7 +23,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
self.route_position[end_node_id] = 0
|
||||
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
||||
self.has_output = False
|
||||
self.output_node_ids = set()
|
||||
self.output_node_ids: set[str] = set()
|
||||
|
||||
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
|
||||
for event in generator:
|
||||
|
@@ -42,6 +42,6 @@ class RunRetryEvent(BaseModel):
|
||||
class SingleStepRetryEvent(NodeRunResult):
|
||||
"""Single step retry event"""
|
||||
|
||||
status: str = WorkflowNodeExecutionStatus.RETRY.value
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RETRY
|
||||
|
||||
elapsed_time: float = Field(..., description="elapsed time")
|
||||
|
@@ -107,9 +107,9 @@ class Executor:
|
||||
if not (key := key.strip()):
|
||||
continue
|
||||
|
||||
value = value[0].strip() if value else ""
|
||||
value_str = value[0].strip() if value else ""
|
||||
result.append(
|
||||
(self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value).text)
|
||||
(self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text)
|
||||
)
|
||||
|
||||
self.params = result
|
||||
@@ -182,9 +182,10 @@ class Executor:
|
||||
self.variable_pool.convert_template(item.key).text: item.file
|
||||
for item in filter(lambda item: item.type == "file", data)
|
||||
}
|
||||
files: dict[str, Any] = {}
|
||||
files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()}
|
||||
files = {k: v for k, v in files.items() if v is not None}
|
||||
files = {k: variable.value for k, variable in files.items()}
|
||||
files = {k: variable.value for k, variable in files.items() if variable is not None}
|
||||
files = {
|
||||
k: (v.filename, file_manager.download(v), v.mime_type or "application/octet-stream")
|
||||
for k, v in files.items()
|
||||
@@ -258,7 +259,8 @@ class Executor:
|
||||
response = getattr(ssrf_proxy, self.method)(**request_args)
|
||||
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
||||
raise HttpRequestNodeError(str(e))
|
||||
return response
|
||||
# FIXME: fix type ignore, this maybe httpx type issue
|
||||
return response # type: ignore
|
||||
|
||||
def invoke(self) -> Response:
|
||||
# assemble headers
|
||||
@@ -300,37 +302,37 @@ class Executor:
|
||||
continue
|
||||
raw += f"{k}: {v}\r\n"
|
||||
|
||||
body = ""
|
||||
body_string = ""
|
||||
if self.files:
|
||||
for k, v in self.files.items():
|
||||
body += f"--{boundary}\r\n"
|
||||
body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
|
||||
body += f"{v[1]}\r\n"
|
||||
body += f"--{boundary}--\r\n"
|
||||
body_string += f"--{boundary}\r\n"
|
||||
body_string += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
|
||||
body_string += f"{v[1]}\r\n"
|
||||
body_string += f"--{boundary}--\r\n"
|
||||
elif self.node_data.body:
|
||||
if self.content:
|
||||
if isinstance(self.content, str):
|
||||
body = self.content
|
||||
body_string = self.content
|
||||
elif isinstance(self.content, bytes):
|
||||
body = self.content.decode("utf-8", errors="replace")
|
||||
body_string = self.content.decode("utf-8", errors="replace")
|
||||
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
|
||||
body = urlencode(self.data)
|
||||
body_string = urlencode(self.data)
|
||||
elif self.data and self.node_data.body.type == "form-data":
|
||||
for key, value in self.data.items():
|
||||
body += f"--{boundary}\r\n"
|
||||
body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
|
||||
body += f"{value}\r\n"
|
||||
body += f"--{boundary}--\r\n"
|
||||
body_string += f"--{boundary}\r\n"
|
||||
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
|
||||
body_string += f"{value}\r\n"
|
||||
body_string += f"--{boundary}--\r\n"
|
||||
elif self.json:
|
||||
body = json.dumps(self.json)
|
||||
body_string = json.dumps(self.json)
|
||||
elif self.node_data.body.type == "raw-text":
|
||||
if len(self.node_data.body.data) != 1:
|
||||
raise RequestBodyError("raw-text body type should have exactly one item")
|
||||
body = self.node_data.body.data[0].value
|
||||
if body:
|
||||
raw += f"Content-Length: {len(body)}\r\n"
|
||||
body_string = self.node_data.body.data[0].value
|
||||
if body_string:
|
||||
raw += f"Content-Length: {len(body_string)}\r\n"
|
||||
raw += "\r\n" # Empty line between headers and body
|
||||
raw += body
|
||||
raw += body_string
|
||||
|
||||
return raw
|
||||
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod
|
||||
@@ -36,7 +36,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
_node_type = NodeType.HTTP_REQUEST
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: dict | None = None) -> dict:
|
||||
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
|
||||
return {
|
||||
"type": "http-request",
|
||||
"config": {
|
||||
@@ -160,8 +160,8 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
)
|
||||
|
||||
mapping = {}
|
||||
for selector in selectors:
|
||||
mapping[node_id + "." + selector.variable] = selector.value_selector
|
||||
for selector_iter in selectors:
|
||||
mapping[node_id + "." + selector_iter.variable] = selector_iter.value_selector
|
||||
|
||||
return mapping
|
||||
|
||||
|
@@ -361,13 +361,16 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
metadata = event.route_node_state.node_run_result.metadata
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
|
||||
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
||||
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
||||
if self.node_data.is_parallel:
|
||||
metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
|
||||
else:
|
||||
metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index
|
||||
metadata = {
|
||||
**metadata,
|
||||
NodeRunMetadataKey.ITERATION_ID: self.node_id,
|
||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID
|
||||
if self.node_data.is_parallel
|
||||
else NodeRunMetadataKey.ITERATION_INDEX: parallel_mode_run_id
|
||||
if self.node_data.is_parallel
|
||||
else iter_run_index,
|
||||
}
|
||||
event.route_node_state.node_run_result.metadata = metadata
|
||||
return event
|
||||
|
||||
|
@@ -147,6 +147,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
planning_strategy=planning_strategy,
|
||||
)
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
|
||||
if node_data.multiple_retrieval_config is None:
|
||||
raise ValueError("multiple_retrieval_config is required")
|
||||
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
|
||||
if node_data.multiple_retrieval_config.reranking_model:
|
||||
reranking_model = {
|
||||
@@ -157,6 +159,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
reranking_model = None
|
||||
weights = None
|
||||
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
|
||||
if node_data.multiple_retrieval_config.weights is None:
|
||||
raise ValueError("weights is required")
|
||||
reranking_model = None
|
||||
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
|
||||
weights = {
|
||||
@@ -180,7 +184,9 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
available_datasets=available_datasets,
|
||||
query=query,
|
||||
top_k=node_data.multiple_retrieval_config.top_k,
|
||||
score_threshold=node_data.multiple_retrieval_config.score_threshold,
|
||||
score_threshold=node_data.multiple_retrieval_config.score_threshold
|
||||
if node_data.multiple_retrieval_config.score_threshold is not None
|
||||
else 0.0,
|
||||
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
|
||||
reranking_model=reranking_model,
|
||||
weights=weights,
|
||||
@@ -205,7 +211,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
"content": item.page_content,
|
||||
}
|
||||
retrieval_resource_list.append(source)
|
||||
document_score_list = {}
|
||||
document_score_list: dict[str, float] = {}
|
||||
# deal with dify documents
|
||||
if dify_documents:
|
||||
document_score_list = {}
|
||||
@@ -260,7 +266,9 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
retrieval_resource_list.append(source)
|
||||
if retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(
|
||||
retrieval_resource_list, key=lambda x: x.get("metadata").get("score") or 0.0, reverse=True
|
||||
retrieval_resource_list,
|
||||
key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
|
||||
reverse=True,
|
||||
)
|
||||
position = 1
|
||||
for item in retrieval_resource_list:
|
||||
@@ -295,6 +303,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
if node_data.single_retrieval_config is None:
|
||||
raise ValueError("single_retrieval_config is required")
|
||||
model_name = node_data.single_retrieval_config.model.name
|
||||
provider_name = node_data.single_retrieval_config.model.provider
|
||||
|
||||
|
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Literal, Union
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from core.file import File
|
||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||
@@ -17,9 +17,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
||||
_node_type = NodeType.LIST_OPERATOR
|
||||
|
||||
def _run(self):
|
||||
inputs = {}
|
||||
process_data = {}
|
||||
outputs = {}
|
||||
inputs: dict[str, list] = {}
|
||||
process_data: dict[str, list] = {}
|
||||
outputs: dict[str, Any] = {}
|
||||
|
||||
variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
|
||||
if variable is None:
|
||||
@@ -93,6 +93,8 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
||||
def _apply_filter(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
filter_func: Callable[[Any], bool]
|
||||
result: list[Any] = []
|
||||
for condition in self.node_data.filter_by.conditions:
|
||||
if isinstance(variable, ArrayStringSegment):
|
||||
if not isinstance(condition.value, str):
|
||||
@@ -236,6 +238,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
|
||||
|
||||
|
||||
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
|
||||
extract_func: Callable[[File], Any]
|
||||
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
|
||||
extract_func = _get_file_extract_string_func(key=key)
|
||||
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
|
||||
@@ -249,47 +252,47 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
|
||||
raise InvalidKeyError(f"Invalid key: {key}")
|
||||
|
||||
|
||||
def _contains(value: str):
|
||||
def _contains(value: str) -> Callable[[str], bool]:
|
||||
return lambda x: value in x
|
||||
|
||||
|
||||
def _startswith(value: str):
|
||||
def _startswith(value: str) -> Callable[[str], bool]:
|
||||
return lambda x: x.startswith(value)
|
||||
|
||||
|
||||
def _endswith(value: str):
|
||||
def _endswith(value: str) -> Callable[[str], bool]:
|
||||
return lambda x: x.endswith(value)
|
||||
|
||||
|
||||
def _is(value: str):
|
||||
def _is(value: str) -> Callable[[str], bool]:
|
||||
return lambda x: x is value
|
||||
|
||||
|
||||
def _in(value: str | Sequence[str]):
|
||||
def _in(value: str | Sequence[str]) -> Callable[[str], bool]:
|
||||
return lambda x: x in value
|
||||
|
||||
|
||||
def _eq(value: int | float):
|
||||
def _eq(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x == value
|
||||
|
||||
|
||||
def _ne(value: int | float):
|
||||
def _ne(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x != value
|
||||
|
||||
|
||||
def _lt(value: int | float):
|
||||
def _lt(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x < value
|
||||
|
||||
|
||||
def _le(value: int | float):
|
||||
def _le(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x <= value
|
||||
|
||||
|
||||
def _gt(value: int | float):
|
||||
def _gt(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x > value
|
||||
|
||||
|
||||
def _ge(value: int | float):
|
||||
def _ge(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x >= value
|
||||
|
||||
|
||||
@@ -302,6 +305,7 @@ def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
|
||||
|
||||
|
||||
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
|
||||
extract_func: Callable[[File], Any]
|
||||
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
|
||||
extract_func = _get_file_extract_string_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
|
||||
|
@@ -88,8 +88,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
_node_data_cls = LLMNodeData
|
||||
_node_type = NodeType.LLM
|
||||
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
node_inputs = None
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
node_inputs: Optional[dict[str, Any]] = None
|
||||
process_data = None
|
||||
|
||||
try:
|
||||
@@ -196,7 +196,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
@@ -206,7 +205,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
process_data=process_data,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
||||
|
||||
@@ -302,7 +300,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
return messages
|
||||
|
||||
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
|
||||
variables = {}
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
if not node_data.prompt_config:
|
||||
return variables
|
||||
@@ -319,7 +317,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
"""
|
||||
# check if it's a context structure
|
||||
if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
|
||||
return input_dict["content"]
|
||||
return str(input_dict["content"])
|
||||
|
||||
# else, parse the dict
|
||||
try:
|
||||
@@ -557,7 +555,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
variable_pool: VariablePool,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
|
||||
prompt_messages = []
|
||||
# FIXME: fix the type error cause prompt_messages is type quick a few times
|
||||
prompt_messages: list[Any] = []
|
||||
|
||||
if isinstance(prompt_template, list):
|
||||
# For chat model
|
||||
@@ -783,7 +782,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
else:
|
||||
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
|
||||
|
||||
variable_mapping = {}
|
||||
variable_mapping: dict[str, Any] = {}
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
@@ -981,7 +980,7 @@ def _handle_memory_chat_mode(
|
||||
memory_config: MemoryConfig | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
) -> Sequence[PromptMessage]:
|
||||
memory_messages = []
|
||||
memory_messages: Sequence[PromptMessage] = []
|
||||
# Get messages from memory for chat model
|
||||
if memory and memory_config:
|
||||
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
|
||||
|
@@ -14,8 +14,8 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
_node_data_cls = LoopNodeData
|
||||
_node_type = NodeType.LOOP
|
||||
|
||||
def _run(self) -> LoopState:
|
||||
return super()._run()
|
||||
def _run(self) -> LoopState: # type: ignore
|
||||
return super()._run() # type: ignore
|
||||
|
||||
@classmethod
|
||||
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
|
||||
@@ -28,7 +28,7 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
|
||||
# TODO waiting for implementation
|
||||
return [
|
||||
Condition(
|
||||
Condition( # type: ignore
|
||||
variable_selector=[node_id, "index"],
|
||||
comparison_operator="≤",
|
||||
value_type="value_selector",
|
||||
|
@@ -25,7 +25,7 @@ class ParameterConfig(BaseModel):
|
||||
raise ValueError("Parameter name is required")
|
||||
if value in {"__reason", "__is_success"}:
|
||||
raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
class ParameterExtractorNodeData(BaseNodeData):
|
||||
@@ -52,7 +52,7 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
|
||||
:return: parameter json schema
|
||||
"""
|
||||
parameters = {"type": "object", "properties": {}, "required": []}
|
||||
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
|
||||
|
||||
for parameter in self.parameters:
|
||||
parameter_schema: dict[str, Any] = {"description": parameter.description}
|
||||
|
@@ -63,7 +63,8 @@ class ParameterExtractorNode(LLMNode):
|
||||
Parameter Extractor Node.
|
||||
"""
|
||||
|
||||
_node_data_cls = ParameterExtractorNodeData
|
||||
# FIXME: figure out why here is different from super class
|
||||
_node_data_cls = ParameterExtractorNodeData # type: ignore
|
||||
_node_type = NodeType.PARAMETER_EXTRACTOR
|
||||
|
||||
_model_instance: Optional[ModelInstance] = None
|
||||
@@ -253,6 +254,9 @@ class ParameterExtractorNode(LLMNode):
|
||||
# deduct quota
|
||||
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
if text is None:
|
||||
text = ""
|
||||
|
||||
return text, usage, tool_call
|
||||
|
||||
def _generate_function_call_prompt(
|
||||
@@ -605,9 +609,10 @@ class ParameterExtractorNode(LLMNode):
|
||||
json_str = extract_json(result[idx:])
|
||||
if json_str:
|
||||
try:
|
||||
return json.loads(json_str)
|
||||
return cast(dict, json.loads(json_str))
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
|
||||
"""
|
||||
@@ -616,13 +621,13 @@ class ParameterExtractorNode(LLMNode):
|
||||
if not tool_call or not tool_call.function.arguments:
|
||||
return None
|
||||
|
||||
return json.loads(tool_call.function.arguments)
|
||||
return cast(dict, json.loads(tool_call.function.arguments))
|
||||
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
|
||||
"""
|
||||
Generate default result.
|
||||
"""
|
||||
result = {}
|
||||
result: dict[str, Any] = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.type == "number":
|
||||
result[parameter.name] = 0
|
||||
@@ -772,7 +777,7 @@ class ParameterExtractorNode(LLMNode):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
node_data: ParameterExtractorNodeData, # type: ignore
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -781,6 +786,7 @@ class ParameterExtractorNode(LLMNode):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
# FIXME: fix the type error later
|
||||
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
|
||||
|
||||
if node_data.instruction:
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters"
|
||||
|
||||
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy.
|
||||
@@ -35,7 +37,7 @@ FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information fr
|
||||
</structure>
|
||||
""" # noqa: E501
|
||||
|
||||
FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [
|
||||
FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [
|
||||
{
|
||||
"user": {
|
||||
"query": "What is the weather today in SF?",
|
||||
|
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
@@ -34,12 +34,9 @@ from .template_prompts import (
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_3,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file import File
|
||||
|
||||
|
||||
class QuestionClassifierNode(LLMNode):
|
||||
_node_data_cls = QuestionClassifierNodeData
|
||||
_node_data_cls = QuestionClassifierNodeData # type: ignore
|
||||
_node_type = NodeType.QUESTION_CLASSIFIER
|
||||
|
||||
def _run(self):
|
||||
@@ -61,7 +58,7 @@ class QuestionClassifierNode(LLMNode):
|
||||
node_data.instruction = node_data.instruction or ""
|
||||
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
|
||||
|
||||
files: Sequence[File] = (
|
||||
files = (
|
||||
self._fetch_files(
|
||||
selector=node_data.vision.configs.variable_selector,
|
||||
)
|
||||
@@ -168,7 +165,7 @@ class QuestionClassifierNode(LLMNode):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: QuestionClassifierNodeData,
|
||||
node_data: Any,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -177,6 +174,7 @@ class QuestionClassifierNode(LLMNode):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
node_data = cast(QuestionClassifierNodeData, node_data)
|
||||
variable_mapping = {"query": node_data.query_variable_selector}
|
||||
variable_selectors = []
|
||||
if node_data.instruction:
|
||||
|
@@ -9,7 +9,6 @@ from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCal
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
@@ -46,6 +45,8 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
|
||||
# get tool runtime
|
||||
try:
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
|
||||
)
|
||||
@@ -142,7 +143,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
"""
|
||||
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
|
||||
|
||||
result = {}
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
parameter = tool_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
@@ -264,9 +265,9 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
"""
|
||||
return "\n".join(
|
||||
[
|
||||
f"{message.message}"
|
||||
str(message.message)
|
||||
if message.type == ToolInvokeMessage.MessageType.TEXT
|
||||
else f"Link: {message.message}"
|
||||
else f"Link: {str(message.message)}"
|
||||
for message in tool_response
|
||||
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}
|
||||
]
|
||||
|
@@ -36,6 +36,8 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
|
||||
case WriteMode.CLEAR:
|
||||
income_value = get_zero_value(original_variable.value_type)
|
||||
if income_value is None:
|
||||
raise VariableOperatorNodeError("income value not found")
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
||||
|
||||
case _:
|
||||
|
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
@@ -29,7 +29,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
inputs = self.node_data.model_dump()
|
||||
process_data = {}
|
||||
process_data: dict[str, Any] = {}
|
||||
# NOTE: This node has no outputs
|
||||
updated_variables: list[Variable] = []
|
||||
|
||||
@@ -119,7 +119,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
||||
else:
|
||||
conversation_id = conversation_id.value
|
||||
common_helpers.update_conversation_variable(
|
||||
conversation_id=conversation_id,
|
||||
conversation_id=cast(str, conversation_id),
|
||||
variable=variable,
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user