feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -63,7 +63,7 @@ class AnswerStreamProcessor(StreamProcessor):
self._remove_unreachable_nodes(event)
# generate stream outputs
yield from self._generate_stream_outputs_when_node_finished(event)
yield from self._generate_stream_outputs_when_node_finished(cast(NodeRunSucceededEvent, event))
else:
yield event
@@ -130,7 +130,7 @@ class AnswerStreamProcessor(StreamProcessor):
node_type=event.node_type,
node_data=event.node_data,
chunk_content=text,
from_variable_selector=value_selector,
from_variable_selector=list(value_selector),
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,

View File

@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
from collections.abc import Generator
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph
logger = logging.getLogger(__name__)
@@ -19,7 +19,7 @@ class StreamProcessor(ABC):
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
raise NotImplementedError
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
@@ -32,8 +32,8 @@ class StreamProcessor(ABC):
return
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
reachable_node_ids: list[str] = []
unreachable_first_node_ids: list[str] = []
if finished_node_id not in self.graph.edge_mapping:
logger.warning(f"node {finished_node_id} has no edge mapping")
return

View File

@@ -38,7 +38,8 @@ class DefaultValue(BaseModel):
@staticmethod
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
"""Unified array type validation"""
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
# FIXME: type-ignored because the reason mypy complains here has not been found; if you find the root cause, please fix it
return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore
@staticmethod
def _convert_number(value: str) -> float:
@@ -84,7 +85,7 @@ class DefaultValue(BaseModel):
},
}
validator = type_validators.get(self.type)
validator: dict[str, Any] = type_validators.get(self.type, {})
if not validator:
if self.type == DefaultValueType.ARRAY_FILES:
# Handle files type

View File

@@ -125,7 +125,7 @@ class CodeNode(BaseNode[CodeNodeData]):
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
transformed_result = {}
transformed_result: dict[str, Any] = {}
if output_schema is None:
# validate output through instance type
for output_name, output_value in result.items():

View File

@@ -14,7 +14,7 @@ class CodeNodeData(BaseNodeData):
class Output(BaseModel):
type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
children: Optional[dict[str, "Output"]] = None
children: Optional[dict[str, "CodeNodeData.Output"]] = None
class Dependency(BaseModel):
name: str

View File

@@ -4,6 +4,7 @@ import json
import logging
import os
import tempfile
from typing import cast
import docx
import pandas as pd
@@ -159,7 +160,7 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
"""Extract the content from yaml file"""
try:
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", "ignore"))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
except (UnicodeDecodeError, yaml.YAMLError) as e:
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
@@ -229,9 +230,9 @@ def _download_file_content(file: File) -> bytes:
raise FileDownloadError("Missing URL for remote file")
response = ssrf_proxy.get(file.remote_url)
response.raise_for_status()
return response.content
return cast(bytes, response.content)
else:
return file_manager.download(file)
return cast(bytes, file_manager.download(file))
except Exception as e:
raise FileDownloadError(f"Error downloading file: {str(e)}") from e

View File

@@ -67,7 +67,7 @@ class EndStreamGeneratorRouter:
and node_type == NodeType.LLM.value
and variable_selector.value_selector[1] == "text"
):
value_selectors.append(variable_selector.value_selector)
value_selectors.append(list(variable_selector.value_selector))
return value_selectors
@@ -119,8 +119,7 @@ class EndStreamGeneratorRouter:
current_node_id: str,
end_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]],
# type: ignore[name-defined]
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
end_dependencies: dict[str, list[str]],
) -> None:
"""

View File

@@ -23,7 +23,7 @@ class EndStreamProcessor(StreamProcessor):
self.route_position[end_node_id] = 0
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
self.has_output = False
self.output_node_ids = set()
self.output_node_ids: set[str] = set()
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
for event in generator:

View File

@@ -42,6 +42,6 @@ class RunRetryEvent(BaseModel):
class SingleStepRetryEvent(NodeRunResult):
"""Single step retry event"""
status: str = WorkflowNodeExecutionStatus.RETRY.value
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RETRY
elapsed_time: float = Field(..., description="elapsed time")

View File

@@ -107,9 +107,9 @@ class Executor:
if not (key := key.strip()):
continue
value = value[0].strip() if value else ""
value_str = value[0].strip() if value else ""
result.append(
(self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value).text)
(self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text)
)
self.params = result
@@ -182,9 +182,10 @@ class Executor:
self.variable_pool.convert_template(item.key).text: item.file
for item in filter(lambda item: item.type == "file", data)
}
files: dict[str, Any] = {}
files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()}
files = {k: v for k, v in files.items() if v is not None}
files = {k: variable.value for k, variable in files.items()}
files = {k: variable.value for k, variable in files.items() if variable is not None}
files = {
k: (v.filename, file_manager.download(v), v.mime_type or "application/octet-stream")
for k, v in files.items()
@@ -258,7 +259,8 @@ class Executor:
response = getattr(ssrf_proxy, self.method)(**request_args)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
raise HttpRequestNodeError(str(e))
return response
# FIXME: remove this type ignore; this may be an httpx typing issue
return response # type: ignore
def invoke(self) -> Response:
# assemble headers
@@ -300,37 +302,37 @@ class Executor:
continue
raw += f"{k}: {v}\r\n"
body = ""
body_string = ""
if self.files:
for k, v in self.files.items():
body += f"--{boundary}\r\n"
body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
body += f"{v[1]}\r\n"
body += f"--{boundary}--\r\n"
body_string += f"--{boundary}\r\n"
body_string += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
body_string += f"{v[1]}\r\n"
body_string += f"--{boundary}--\r\n"
elif self.node_data.body:
if self.content:
if isinstance(self.content, str):
body = self.content
body_string = self.content
elif isinstance(self.content, bytes):
body = self.content.decode("utf-8", errors="replace")
body_string = self.content.decode("utf-8", errors="replace")
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
body = urlencode(self.data)
body_string = urlencode(self.data)
elif self.data and self.node_data.body.type == "form-data":
for key, value in self.data.items():
body += f"--{boundary}\r\n"
body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
body += f"{value}\r\n"
body += f"--{boundary}--\r\n"
body_string += f"--{boundary}\r\n"
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
body_string += f"{value}\r\n"
body_string += f"--{boundary}--\r\n"
elif self.json:
body = json.dumps(self.json)
body_string = json.dumps(self.json)
elif self.node_data.body.type == "raw-text":
if len(self.node_data.body.data) != 1:
raise RequestBodyError("raw-text body type should have exactly one item")
body = self.node_data.body.data[0].value
if body:
raw += f"Content-Length: {len(body)}\r\n"
body_string = self.node_data.body.data[0].value
if body_string:
raw += f"Content-Length: {len(body_string)}\r\n"
raw += "\r\n" # Empty line between headers and body
raw += body
raw += body_string
return raw

View File

@@ -1,7 +1,7 @@
import logging
import mimetypes
from collections.abc import Mapping, Sequence
from typing import Any
from typing import Any, Optional
from configs import dify_config
from core.file import File, FileTransferMethod
@@ -36,7 +36,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
_node_type = NodeType.HTTP_REQUEST
@classmethod
def get_default_config(cls, filters: dict | None = None) -> dict:
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
return {
"type": "http-request",
"config": {
@@ -160,8 +160,8 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
)
mapping = {}
for selector in selectors:
mapping[node_id + "." + selector.variable] = selector.value_selector
for selector_iter in selectors:
mapping[node_id + "." + selector_iter.variable] = selector_iter.value_selector
return mapping

View File

@@ -361,13 +361,16 @@ class IterationNode(BaseNode[IterationNodeData]):
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
if self.node_data.is_parallel:
metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
else:
metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index
metadata = {
**metadata,
NodeRunMetadataKey.ITERATION_ID: self.node_id,
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID
if self.node_data.is_parallel
else NodeRunMetadataKey.ITERATION_INDEX: parallel_mode_run_id
if self.node_data.is_parallel
else iter_run_index,
}
event.route_node_state.node_run_result.metadata = metadata
return event

View File

@@ -147,6 +147,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
planning_strategy=planning_strategy,
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:
reranking_model = {
@@ -157,6 +159,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
reranking_model = None
weights = None
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
if node_data.multiple_retrieval_config.weights is None:
raise ValueError("weights is required")
reranking_model = None
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
weights = {
@@ -180,7 +184,9 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
available_datasets=available_datasets,
query=query,
top_k=node_data.multiple_retrieval_config.top_k,
score_threshold=node_data.multiple_retrieval_config.score_threshold,
score_threshold=node_data.multiple_retrieval_config.score_threshold
if node_data.multiple_retrieval_config.score_threshold is not None
else 0.0,
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
reranking_model=reranking_model,
weights=weights,
@@ -205,7 +211,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
"content": item.page_content,
}
retrieval_resource_list.append(source)
document_score_list = {}
document_score_list: dict[str, float] = {}
# deal with dify documents
if dify_documents:
document_score_list = {}
@@ -260,7 +266,9 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
retrieval_resource_list.append(source)
if retrieval_resource_list:
retrieval_resource_list = sorted(
retrieval_resource_list, key=lambda x: x.get("metadata").get("score") or 0.0, reverse=True
retrieval_resource_list,
key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
reverse=True,
)
position = 1
for item in retrieval_resource_list:
@@ -295,6 +303,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
:param node_data: node data
:return:
"""
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required")
model_name = node_data.single_retrieval_config.model.name
provider_name = node_data.single_retrieval_config.model.provider

View File

@@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence
from typing import Literal, Union
from typing import Any, Literal, Union
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
@@ -17,9 +17,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
_node_type = NodeType.LIST_OPERATOR
def _run(self):
inputs = {}
process_data = {}
outputs = {}
inputs: dict[str, list] = {}
process_data: dict[str, list] = {}
outputs: dict[str, Any] = {}
variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
if variable is None:
@@ -93,6 +93,8 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
def _apply_filter(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
filter_func: Callable[[Any], bool]
result: list[Any] = []
for condition in self.node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str):
@@ -236,6 +238,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
extract_func: Callable[[File], Any]
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
@@ -249,47 +252,47 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
raise InvalidKeyError(f"Invalid key: {key}")
def _contains(value: str):
def _contains(value: str) -> Callable[[str], bool]:
return lambda x: value in x
def _startswith(value: str):
def _startswith(value: str) -> Callable[[str], bool]:
return lambda x: x.startswith(value)
def _endswith(value: str):
def _endswith(value: str) -> Callable[[str], bool]:
return lambda x: x.endswith(value)
def _is(value: str):
def _is(value: str) -> Callable[[str], bool]:
return lambda x: x is value
def _in(value: str | Sequence[str]):
def _in(value: str | Sequence[str]) -> Callable[[str], bool]:
return lambda x: x in value
def _eq(value: int | float):
def _eq(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x == value
def _ne(value: int | float):
def _ne(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x != value
def _lt(value: int | float):
def _lt(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x < value
def _le(value: int | float):
def _le(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x <= value
def _gt(value: int | float):
def _gt(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x > value
def _ge(value: int | float):
def _ge(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x >= value
@@ -302,6 +305,7 @@ def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
extract_func: Callable[[File], Any]
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
extract_func = _get_file_extract_string_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")

View File

@@ -88,8 +88,8 @@ class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM
def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]:
node_inputs = None
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
node_inputs: Optional[dict[str, Any]] = None
process_data = None
try:
@@ -196,7 +196,6 @@ class LLMNode(BaseNode[LLMNodeData]):
error_type=type(e).__name__,
)
)
return
except Exception as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
@@ -206,7 +205,6 @@ class LLMNode(BaseNode[LLMNodeData]):
process_data=process_data,
)
)
return
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
@@ -302,7 +300,7 @@ class LLMNode(BaseNode[LLMNodeData]):
return messages
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
variables = {}
variables: dict[str, Any] = {}
if not node_data.prompt_config:
return variables
@@ -319,7 +317,7 @@ class LLMNode(BaseNode[LLMNodeData]):
"""
# check if it's a context structure
if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
return input_dict["content"]
return str(input_dict["content"])
# else, parse the dict
try:
@@ -557,7 +555,8 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_messages = []
# FIXME: fix the type error; it arises because prompt_messages is re-typed quite a few times
prompt_messages: list[Any] = []
if isinstance(prompt_template, list):
# For chat model
@@ -783,7 +782,7 @@ class LLMNode(BaseNode[LLMNodeData]):
else:
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
variable_mapping = {}
variable_mapping: dict[str, Any] = {}
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
@@ -981,7 +980,7 @@ def _handle_memory_chat_mode(
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
memory_messages = []
memory_messages: Sequence[PromptMessage] = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)

View File

@@ -14,8 +14,8 @@ class LoopNode(BaseNode[LoopNodeData]):
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP
def _run(self) -> LoopState:
return super()._run()
def _run(self) -> LoopState: # type: ignore
return super()._run() # type: ignore
@classmethod
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
@@ -28,7 +28,7 @@ class LoopNode(BaseNode[LoopNodeData]):
# TODO waiting for implementation
return [
Condition(
Condition( # type: ignore
variable_selector=[node_id, "index"],
comparison_operator="",
value_type="value_selector",

View File

@@ -25,7 +25,7 @@ class ParameterConfig(BaseModel):
raise ValueError("Parameter name is required")
if value in {"__reason", "__is_success"}:
raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
return value
return str(value)
class ParameterExtractorNodeData(BaseNodeData):
@@ -52,7 +52,7 @@ class ParameterExtractorNodeData(BaseNodeData):
:return: parameter json schema
"""
parameters = {"type": "object", "properties": {}, "required": []}
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
for parameter in self.parameters:
parameter_schema: dict[str, Any] = {"description": parameter.description}

View File

@@ -63,7 +63,8 @@ class ParameterExtractorNode(LLMNode):
Parameter Extractor Node.
"""
_node_data_cls = ParameterExtractorNodeData
# FIXME: figure out why this differs from the superclass
_node_data_cls = ParameterExtractorNodeData # type: ignore
_node_type = NodeType.PARAMETER_EXTRACTOR
_model_instance: Optional[ModelInstance] = None
@@ -253,6 +254,9 @@ class ParameterExtractorNode(LLMNode):
# deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
if text is None:
text = ""
return text, usage, tool_call
def _generate_function_call_prompt(
@@ -605,9 +609,10 @@ class ParameterExtractorNode(LLMNode):
json_str = extract_json(result[idx:])
if json_str:
try:
return json.loads(json_str)
return cast(dict, json.loads(json_str))
except Exception:
pass
return None
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
"""
@@ -616,13 +621,13 @@ class ParameterExtractorNode(LLMNode):
if not tool_call or not tool_call.function.arguments:
return None
return json.loads(tool_call.function.arguments)
return cast(dict, json.loads(tool_call.function.arguments))
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
"""
Generate default result.
"""
result = {}
result: dict[str, Any] = {}
for parameter in data.parameters:
if parameter.type == "number":
result[parameter.name] = 0
@@ -772,7 +777,7 @@ class ParameterExtractorNode(LLMNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ParameterExtractorNodeData,
node_data: ParameterExtractorNodeData, # type: ignore
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -781,6 +786,7 @@ class ParameterExtractorNode(LLMNode):
:param node_data: node data
:return:
"""
# FIXME: fix the type error later
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
if node_data.instruction:

View File

@@ -1,3 +1,5 @@
from typing import Any
FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters"
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy.
@@ -35,7 +37,7 @@ FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information fr
</structure>
""" # noqa: E501
FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [
FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [
{
"user": {
"query": "What is the weather today in SF?",

View File

@@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -34,12 +34,9 @@ from .template_prompts import (
QUESTION_CLASSIFIER_USER_PROMPT_3,
)
if TYPE_CHECKING:
from core.file import File
class QuestionClassifierNode(LLMNode):
_node_data_cls = QuestionClassifierNodeData
_node_data_cls = QuestionClassifierNodeData # type: ignore
_node_type = NodeType.QUESTION_CLASSIFIER
def _run(self):
@@ -61,7 +58,7 @@ class QuestionClassifierNode(LLMNode):
node_data.instruction = node_data.instruction or ""
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
files: Sequence[File] = (
files = (
self._fetch_files(
selector=node_data.vision.configs.variable_selector,
)
@@ -168,7 +165,7 @@ class QuestionClassifierNode(LLMNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: QuestionClassifierNodeData,
node_data: Any,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -177,6 +174,7 @@ class QuestionClassifierNode(LLMNode):
:param node_data: node data
:return:
"""
node_data = cast(QuestionClassifierNodeData, node_data)
variable_mapping = {"query": node_data.query_variable_selector}
variable_selectors = []
if node_data.instruction:

View File

@@ -9,7 +9,6 @@ from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCal
from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
@@ -46,6 +45,8 @@ class ToolNode(BaseNode[ToolNodeData]):
# get tool runtime
try:
from core.tools.tool_manager import ToolManager
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
)
@@ -142,7 +143,7 @@ class ToolNode(BaseNode[ToolNodeData]):
"""
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
result = {}
result: dict[str, Any] = {}
for parameter_name in node_data.tool_parameters:
parameter = tool_parameters_dictionary.get(parameter_name)
if not parameter:
@@ -264,9 +265,9 @@ class ToolNode(BaseNode[ToolNodeData]):
"""
return "\n".join(
[
f"{message.message}"
str(message.message)
if message.type == ToolInvokeMessage.MessageType.TEXT
else f"Link: {message.message}"
else f"Link: {str(message.message)}"
for message in tool_response
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}
]

View File

@@ -36,6 +36,8 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
if income_value is None:
raise VariableOperatorNodeError("income value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:

View File

@@ -1,5 +1,5 @@
import json
from typing import Any
from typing import Any, cast
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
@@ -29,7 +29,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump()
process_data = {}
process_data: dict[str, Any] = {}
# NOTE: This node has no outputs
updated_variables: list[Variable] = []
@@ -119,7 +119,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
else:
conversation_id = conversation_id.value
common_helpers.update_conversation_variable(
conversation_id=conversation_id,
conversation_id=cast(str, conversation_id),
variable=variable,
)