feat: Parallel Execution of Nodes in Workflows (#8192)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Author: takatost
Date: 2024-09-10 15:23:16 +08:00 (committed by GitHub)
Parent: 5da0182800
Commit: dabfd74622

156 changed files with 11158 additions and 5605 deletions

View File

@@ -1,9 +1,8 @@
from typing import cast
from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
GenerateRouteChunk,
@@ -19,24 +18,26 @@ class AnswerNode(BaseNode):
_node_data_cls = AnswerNodeData
_node_type: NodeType = NodeType.ANSWER
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:return:
"""
node_data = self.node_data
node_data = cast(AnswerNodeData, node_data)
# generate routes
generate_routes = self.extract_generate_route_from_node_data(node_data)
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data)
answer = ''
for part in generate_routes:
if part.type == "var":
if part.type == GenerateRouteChunk.ChunkType.VAR:
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
value = variable_pool.get(value_selector)
value = self.graph_runtime_state.variable_pool.get(
value_selector
)
if value:
answer += value.markdown
else:
@@ -51,70 +52,16 @@ class AnswerNode(BaseNode):
)
@classmethod
def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
"""
Extract generate route selectors
:param config: node config
:return:
"""
node_data = cls._node_data_cls(**config.get("data", {}))
node_data = cast(AnswerNodeData, node_data)
return cls.extract_generate_route_from_node_data(node_data)
@classmethod
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
"""
Extract generate route from node data
:param node_data: node data object
:return:
"""
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
value_selector_mapping = {
variable_selector.variable: variable_selector.value_selector
for variable_selector in variable_selectors
}
variable_keys = list(value_selector_mapping.keys())
# format answer template
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
template_variable_keys = template_parser.variable_keys
# Take the intersection of variable_keys and template_variable_keys
variable_keys = list(set(variable_keys) & set(template_variable_keys))
template = node_data.answer
for var in variable_keys:
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
generate_routes = []
for part in template.split('Ω'):
if part:
if cls._is_variable(part, variable_keys):
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
value_selector = value_selector_mapping[var_key]
generate_routes.append(VarGenerateRouteChunk(
value_selector=value_selector
))
else:
generate_routes.append(TextGenerateRouteChunk(
text=part
))
return generate_routes
@classmethod
def _is_variable(cls, part, variable_keys):
cleaned_part = part.replace('{{', '').replace('}}', '')
return part.startswith('{{') and cleaned_part in variable_keys
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: AnswerNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
@@ -126,6 +73,6 @@ class AnswerNode(BaseNode):
variable_mapping = {}
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector
return variable_mapping
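
With nodes able to run in parallel, variable mappings from different nodes can no longer share a flat key space, which is why the keys gain a `node_id + '.'` prefix above. A minimal illustrative sketch (toy dicts, not from this commit) of the collision the prefix avoids:

```python
# Two answer nodes each expose a variable named "text"; without the node-id
# prefix, merging the two mappings would silently overwrite one entry.
mapping_a = {"answer_1.text": ["llm_1", "text"]}
mapping_b = {"answer_2.text": ["llm_2", "text"]}

merged = {**mapping_a, **mapping_b}
assert len(merged) == 2  # both selectors survive thanks to the prefix
```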

View File

@@ -0,0 +1,169 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
AnswerStreamGenerateRoute,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerStreamGeneratorRouter:
@classmethod
def init(cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined]
) -> AnswerStreamGenerateRoute:
"""
Get stream generate routes.
:return:
"""
# parse stream output node value selectors of answer nodes
answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
for answer_node_id, node_config in node_id_config_mapping.items():
if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
continue
# get generate route for stream output
generate_route = cls._extract_generate_route_selectors(node_config)
answer_generate_route[answer_node_id] = generate_route
# fetch answer dependencies
answer_node_ids = list(answer_generate_route.keys())
answer_dependencies = cls._fetch_answers_dependencies(
answer_node_ids=answer_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
node_id_config_mapping=node_id_config_mapping
)
return AnswerStreamGenerateRoute(
answer_generate_route=answer_generate_route,
answer_dependencies=answer_dependencies
)
@classmethod
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
"""
Extract generate route from node data
:param node_data: node data object
:return:
"""
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
value_selector_mapping = {
variable_selector.variable: variable_selector.value_selector
for variable_selector in variable_selectors
}
variable_keys = list(value_selector_mapping.keys())
# format answer template
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
template_variable_keys = template_parser.variable_keys
# Take the intersection of variable_keys and template_variable_keys
variable_keys = list(set(variable_keys) & set(template_variable_keys))
template = node_data.answer
for var in variable_keys:
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
generate_routes: list[GenerateRouteChunk] = []
for part in template.split('Ω'):
if part:
if cls._is_variable(part, variable_keys):
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
value_selector = value_selector_mapping[var_key]
generate_routes.append(VarGenerateRouteChunk(
value_selector=value_selector
))
else:
generate_routes.append(TextGenerateRouteChunk(
text=part
))
return generate_routes
@classmethod
def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
"""
Extract generate route selectors
:param config: node config
:return:
"""
node_data = AnswerNodeData(**config.get("data", {}))
return cls.extract_generate_route_from_node_data(node_data)
@classmethod
def _is_variable(cls, part, variable_keys):
cleaned_part = part.replace('{{', '').replace('}}', '')
return part.startswith('{{') and cleaned_part in variable_keys
@classmethod
def _fetch_answers_dependencies(cls,
answer_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict]
) -> dict[str, list[str]]:
"""
Fetch answer dependencies
:param answer_node_ids: answer node ids
:param reverse_edge_mapping: reverse edge mapping
:param node_id_config_mapping: node id config mapping
:return:
"""
answer_dependencies: dict[str, list[str]] = {}
for answer_node_id in answer_node_ids:
if answer_dependencies.get(answer_node_id) is None:
answer_dependencies[answer_node_id] = []
cls._recursive_fetch_answer_dependencies(
current_node_id=answer_node_id,
answer_node_id=answer_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies
)
return answer_dependencies
@classmethod
def _recursive_fetch_answer_dependencies(cls,
current_node_id: str,
answer_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
answer_dependencies: dict[str, list[str]]
) -> None:
"""
Recursive fetch answer dependencies
:param current_node_id: current node id
:param answer_node_id: answer node id
:param node_id_config_mapping: node id config mapping
:param reverse_edge_mapping: reverse edge mapping
:param answer_dependencies: answer dependencies
:return:
"""
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
if source_node_type in (
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value,
):
answer_dependencies[answer_node_id].append(source_node_id)
else:
cls._recursive_fetch_answer_dependencies(
current_node_id=source_node_id,
answer_node_id=answer_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies
)
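
The Ω-marker split is easiest to see on a concrete template. Below is a self-contained re-creation of that logic (a hypothetical helper, simplified to plain tuples rather than chunk models, and assuming the variable keys were already extracted by the template parser):

```python
def split_answer_template(template: str, variable_keys: set[str]) -> list[tuple[str, str]]:
    # wrap each known {{var}} occurrence in a sentinel character, then split
    # on the sentinel and classify each piece as text or variable
    for var in variable_keys:
        template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
    chunks: list[tuple[str, str]] = []
    for part in template.split('Ω'):
        if not part:
            continue
        key = part.replace('{{', '').replace('}}', '')
        if part.startswith('{{') and key in variable_keys:
            chunks.append(('var', key))
        else:
            chunks.append(('text', part))
    return chunks

# literal text interleaved with one variable reference:
assert split_answer_template(
    'Answer: {{#llm.text#}} done', {'#llm.text#'}
) == [('text', 'Answer: '), ('var', '#llm.text#'), ('text', ' done')]
```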

View File

@@ -0,0 +1,221 @@
import logging
from collections.abc import Generator
from typing import Optional, cast
from core.file.file_obj import FileVar
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
logger = logging.getLogger(__name__)
class AnswerStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
super().__init__(graph, variable_pool)
self.generate_routes = graph.answer_stream_generate_routes
self.route_position = {}
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
self.route_position[answer_node_id] = 0
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
def process(self,
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
if isinstance(event, NodeRunStartedEvent):
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
self.reset()
yield event
elif isinstance(event, NodeRunStreamChunkEvent):
if event.in_iteration_id:
yield event
continue
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
]
else:
stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
] = stream_out_answer_node_ids
for _ in stream_out_answer_node_ids:
yield event
elif isinstance(event, NodeRunSucceededEvent):
yield event
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
# update self.route_position after all stream event finished
for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
self.route_position[answer_node_id] += 1
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
# remove unreachable nodes
self._remove_unreachable_nodes(event)
# generate stream outputs
yield from self._generate_stream_outputs_when_node_finished(event)
else:
yield event
def reset(self) -> None:
self.route_position = {}
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
self.route_position[answer_node_id] = 0
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _generate_stream_outputs_when_node_finished(self,
event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
"""
Generate stream outputs.
:param event: node run succeeded event
:return:
"""
for answer_node_id, position in self.route_position.items():
# all depends on answer node id not in rest node ids
if (event.route_node_state.node_id != answer_node_id
and (answer_node_id not in self.rest_node_ids
or not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
continue
route_position = self.route_position[answer_node_id]
route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
for route_chunk in route_chunks:
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
yield NodeRunStreamChunkEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=route_chunk.text,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
break
value = self.variable_pool.get(
value_selector
)
if value is None:
break
text = value.markdown
if text:
yield NodeRunStreamChunkEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=text,
from_variable_selector=value_selector,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
self.route_position[answer_node_id] += 1
def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
"""
Get the answer node ids that the stream chunk should be routed to
:param event: queue text chunk event
:return:
"""
if not event.from_variable_selector:
return []
stream_output_value_selector = event.from_variable_selector
if not stream_output_value_selector:
return []
stream_out_answer_node_ids = []
for answer_node_id, route_position in self.route_position.items():
if answer_node_id not in self.rest_node_ids:
continue
# all depends on answer node id not in rest node ids
if all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
continue
route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
continue
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
# check chunk node id is before current node id or equal to current node id
if value_selector != stream_output_value_selector:
continue
stream_out_answer_node_ids.append(answer_node_id)
return stream_out_answer_node_ids
@classmethod
def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]:
"""
Fetch files from variable value
:param value: variable value
:return:
"""
if not value:
return []
files = []
if isinstance(value, list):
for item in value:
file_var = cls._get_file_var_from_value(item)
if file_var:
files.append(file_var)
elif isinstance(value, dict):
file_var = cls._get_file_var_from_value(value)
if file_var:
files.append(file_var)
return files
@classmethod
def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]:
"""
Get file var from value
:param value: variable value
:return:
"""
if not value:
return None
if isinstance(value, dict):
if '__variant' in value and value['__variant'] == FileVar.__name__:
return value
elif isinstance(value, FileVar):
return value.to_dict()
return None
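
A hedged usage sketch of the processor: it wraps the engine's event generator and forwards each stream chunk once per answer node that is ready to output it. Construction of `graph`, `variable_pool`, and the engine generator is assumed here, not shown by this diff:

```python
processor = AnswerStreamProcessor(graph=graph, variable_pool=variable_pool)

for event in processor.process(graph_engine.run()):
    if isinstance(event, NodeRunStreamChunkEvent):
        # chunks arrive already filtered and ordered for the answer template
        print(event.chunk_content, end="", flush=True)
```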

View File

@@ -0,0 +1,71 @@
from abc import ABC, abstractmethod
from collections.abc import Generator
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph
class StreamProcessor(ABC):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph
self.variable_pool = variable_pool
self.rest_node_ids = graph.node_ids.copy()
@abstractmethod
def process(self,
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
raise NotImplementedError
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
# remove finished node id
self.rest_node_ids.remove(finished_node_id)
run_result = event.route_node_state.node_run_result
if not run_result:
return
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
for edge in self.graph.edge_mapping[finished_node_id]:
if (edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:
unreachable_first_node_ids.append(edge.target_node_id)
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
"""
remove target node ids until merge
"""
if node_id not in self.rest_node_ids:
return
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
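
The pruning above is what lets streaming start early: once a branch node has chosen an edge, every node reachable only through the untaken edges is dropped from `rest_node_ids`, so answer/end dependencies stop waiting on it. A toy version of the same idea (plain dict edge map and hypothetical node names, not the classes in this diff):

```python
edges = {'if': ['a', 'b'], 'a': ['end'], 'b': ['end']}

def prune(node_id: str, rest: set[str], reachable: set[str]) -> None:
    # stop at nodes already pruned, or kept alive by the taken branch
    if node_id not in rest or node_id in reachable:
        return
    rest.remove(node_id)
    for target in edges.get(node_id, []):
        prune(target, rest, reachable)

rest = {'a', 'b', 'end'}      # 'if' itself already finished
reachable = {'a', 'end'}      # nodes on the branch that was taken
prune('b', rest, reachable)   # walk and drop the untaken branch
assert rest == {'a', 'end'}   # 'b' pruned, merge node 'end' kept
```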

View File

@@ -1,5 +1,6 @@
from enum import Enum
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.workflow.entities.base_node_data_entities import BaseNodeData
@@ -8,27 +9,54 @@ class AnswerNodeData(BaseNodeData):
"""
Answer Node Data.
"""
answer: str
answer: str = Field(..., description="answer template string")
class GenerateRouteChunk(BaseModel):
"""
Generate Route Chunk.
"""
type: str
class ChunkType(Enum):
VAR = "var"
TEXT = "text"
type: ChunkType = Field(..., description="generate route chunk type")
class VarGenerateRouteChunk(GenerateRouteChunk):
"""
Var Generate Route Chunk.
"""
type: str = "var"
value_selector: list[str]
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
"""generate route chunk type"""
value_selector: list[str] = Field(..., description="value selector")
class TextGenerateRouteChunk(GenerateRouteChunk):
"""
Text Generate Route Chunk.
"""
type: str = "text"
text: str
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT
"""generate route chunk type"""
text: str = Field(..., description="text")
class AnswerNodeDoubleLink(BaseModel):
node_id: str = Field(..., description="node id")
source_node_ids: list[str] = Field(..., description="source node ids")
target_node_ids: list[str] = Field(..., description="target node ids")
class AnswerStreamGenerateRoute(BaseModel):
"""
AnswerStreamGenerateRoute entity
"""
answer_dependencies: dict[str, list[str]] = Field(
...,
description="answer dependencies (answer node id -> dependent answer node ids)"
)
answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field(
...,
description="answer generate route (answer node id -> generate route chunks)"
)
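
A short sketch of building a route with the models above; the `ChunkType` enum replaces the old bare `"text"`/`"var"` strings, and the per-subclass defaults make the discriminator explicit:

```python
route = [
    TextGenerateRouteChunk(text="Answer: "),
    VarGenerateRouteChunk(value_selector=["llm", "text"]),
]

assert route[0].type == GenerateRouteChunk.ChunkType.TEXT
assert route[1].type == GenerateRouteChunk.ChunkType.VAR
```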

View File

@@ -1,142 +1,103 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from enum import Enum
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from models import WorkflowNodeExecutionStatus
class UserFrom(Enum):
"""
User from
"""
ACCOUNT = "account"
END_USER = "end-user"
@classmethod
def value_of(cls, value: str) -> "UserFrom":
"""
Value of
:param value: value
:return:
"""
for item in cls:
if item.value == value:
return item
raise ValueError(f"Invalid value: {value}")
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
class BaseNode(ABC):
_node_data_cls: type[BaseNodeData]
_node_type: NodeType
tenant_id: str
app_id: str
workflow_id: str
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
workflow_call_depth: int
node_id: str
node_data: BaseNodeData
node_run_result: Optional[NodeRunResult] = None
callbacks: Sequence[WorkflowCallback]
is_answer_previous_node: bool = False
def __init__(self, tenant_id: str,
app_id: str,
workflow_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
def __init__(self,
id: str,
config: Mapping[str, Any],
callbacks: Sequence[WorkflowCallback] | None = None,
workflow_call_depth: int = 0) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
self.workflow_id = workflow_id
self.user_id = user_id
self.user_from = user_from
self.invoke_from = invoke_from
self.workflow_call_depth = workflow_call_depth
graph_init_params: GraphInitParams,
graph: Graph,
graph_runtime_state: GraphRuntimeState,
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None) -> None:
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
self.workflow_type = graph_init_params.workflow_type
self.workflow_id = graph_init_params.workflow_id
self.graph_config = graph_init_params.graph_config
self.user_id = graph_init_params.user_id
self.user_from = graph_init_params.user_from
self.invoke_from = graph_init_params.invoke_from
self.workflow_call_depth = graph_init_params.call_depth
self.graph = graph
self.graph_runtime_state = graph_runtime_state
self.previous_node_id = previous_node_id
self.thread_pool_id = thread_pool_id
# TODO: May need to check if key exists.
self.node_id = config["id"]
if not self.node_id:
node_id = config.get("id")
if not node_id:
raise ValueError("Node ID is required.")
self.node_id = node_id
self.node_data = self._node_data_cls(**config.get("data", {}))
self.callbacks = callbacks or []
@abstractmethod
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) \
-> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
"""
Run node
:param variable_pool: variable pool
:return:
"""
raise NotImplementedError
def run(self, variable_pool: VariablePool) -> NodeRunResult:
def run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
"""
Run node entry
:param variable_pool: variable pool
:return:
"""
try:
result = self._run(
variable_pool=variable_pool
)
self.node_run_result = result
return result
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
result = self._run()
def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None:
"""
Publish text chunk
:param text: chunk text
:param value_selector: value selector
:return:
"""
if self.callbacks:
for callback in self.callbacks:
callback.on_node_text_chunk(
node_id=self.node_id,
text=text,
metadata={
"node_type": self.node_type,
"is_answer_previous_node": self.is_answer_previous_node,
"value_selector": value_selector
}
)
if isinstance(result, NodeRunResult):
yield RunCompletedEvent(
run_result=result
)
else:
yield from result
@classmethod
def extract_variable_selector_to_variable_mapping(cls, config: dict):
def extract_variable_selector_to_variable_mapping(cls, graph_config: Mapping[str, Any], config: dict) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param config: node config
:return:
"""
node_id = config.get("id")
if not node_id:
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
node_data = cls._node_data_cls(**config.get("data", {}))
return cls._extract_variable_selector_to_variable_mapping(node_data)
return cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config,
node_id=node_id,
node_data=node_data
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> Mapping[str, Sequence[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: BaseNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
@@ -158,38 +119,3 @@ class BaseNode(ABC):
:return:
"""
return self._node_type
class BaseIterationNode(BaseNode):
@abstractmethod
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
"""
Run node
:param variable_pool: variable pool
:return:
"""
raise NotImplementedError
def run(self, variable_pool: VariablePool) -> BaseIterationState:
"""
Run node entry
:param variable_pool: variable pool
:return:
"""
return self._run(variable_pool=variable_pool)
def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
"""
Get next iteration start node id based on the graph.
:param graph: graph
:return: next node id
"""
return self._get_next_iteration(variable_pool, state)
@abstractmethod
def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
"""
Get next iteration start node id based on the graph.
:param graph: graph
:return: next node id
"""
raise NotImplementedError
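
Under the new contract, `_run()` can still return a plain `NodeRunResult`, which `run()` wraps into a single `RunCompletedEvent`, or it can yield stream events itself. A hedged sketch of a minimal subclass (`EchoNode` is hypothetical, and constructor wiring such as `graph_init_params` is omitted):

```python
class EchoNode(BaseNode):
    _node_data_cls = BaseNodeData
    _node_type = NodeType.CODE  # placeholder type for this sketch

    def _run(self) -> NodeRunResult:
        # variable access now goes through the shared graph runtime state
        value = self.graph_runtime_state.variable_pool.get_any(['sys', 'query'])
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={'echo': value},
        )

# node.run() then yields exactly one RunCompletedEvent wrapping the result
```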

View File

@@ -1,4 +1,5 @@
from typing import Optional, Union, cast
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union, cast
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
@@ -6,7 +7,6 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.code.entities import CodeNodeData
from models.workflow import WorkflowNodeExecutionStatus
@@ -33,13 +33,13 @@ class CodeNode(BaseNode):
return code_provider.get_default_config()
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run code
:param variable_pool: variable pool
:return:
"""
node_data = cast(CodeNodeData, self.node_data)
node_data = self.node_data
node_data = cast(CodeNodeData, node_data)
# Get code language
code_language = node_data.code_language
@@ -49,7 +49,7 @@ class CodeNode(BaseNode):
variables = {}
for variable_selector in node_data.variables:
variable = variable_selector.variable
value = variable_pool.get_any(variable_selector.value_selector)
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
variables[variable] = value
# Run code
@@ -311,13 +311,19 @@ class CodeNode(BaseNode):
return transformed_result
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: CodeNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {
variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
}

View File

@@ -1,8 +1,7 @@
from typing import cast
from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.entities import EndNodeData
from models.workflow import WorkflowNodeExecutionStatus
@@ -12,10 +11,9 @@ class EndNode(BaseNode):
_node_data_cls = EndNodeData
_node_type = NodeType.END
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:return:
"""
node_data = self.node_data
@@ -24,7 +22,7 @@ class EndNode(BaseNode):
outputs = {}
for variable_selector in output_variables:
value = variable_pool.get_any(variable_selector.value_selector)
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
outputs[variable_selector.variable] = value
return NodeRunResult(
@@ -34,52 +32,16 @@ class EndNode(BaseNode):
)
@classmethod
def extract_generate_nodes(cls, graph: dict, config: dict) -> list[str]:
"""
Extract generate nodes
:param graph: graph
:param config: node config
:return:
"""
node_data = cls._node_data_cls(**config.get("data", {}))
node_data = cast(EndNodeData, node_data)
return cls.extract_generate_nodes_from_node_data(graph, node_data)
@classmethod
def extract_generate_nodes_from_node_data(cls, graph: dict, node_data: EndNodeData) -> list[str]:
"""
Extract generate nodes from node data
:param graph: graph
:param node_data: node data object
:return:
"""
nodes = graph.get('nodes', [])
node_mapping = {node.get('id'): node for node in nodes}
variable_selectors = node_data.outputs
generate_nodes = []
for variable_selector in variable_selectors:
if not variable_selector.value_selector:
continue
node_id = variable_selector.value_selector[0]
if node_id != 'sys' and node_id in node_mapping:
node = node_mapping[node_id]
node_type = node.get('data', {}).get('type')
if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text':
generate_nodes.append(node_id)
# remove duplicates
generate_nodes = list(set(generate_nodes))
return generate_nodes
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: EndNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""

View File

@@ -0,0 +1,148 @@
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
class EndStreamGeneratorRouter:
@classmethod
def init(cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_parallel_mapping: dict[str, str]
) -> EndStreamParam:
"""
Get stream generate routes.
:return:
"""
# parse stream output node value selector of end nodes
end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
for end_node_id, node_config in node_id_config_mapping.items():
if not node_config.get('data', {}).get('type') == NodeType.END.value:
continue
# skip end node in parallel
if end_node_id in node_parallel_mapping:
continue
# get generate route for stream output
stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config)
end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors
# fetch end dependencies
end_node_ids = list(end_stream_variable_selectors_mapping.keys())
end_dependencies = cls._fetch_ends_dependencies(
end_node_ids=end_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
node_id_config_mapping=node_id_config_mapping
)
return EndStreamParam(
end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping,
end_dependencies=end_dependencies
)
@classmethod
def extract_stream_variable_selector_from_node_data(cls,
node_id_config_mapping: dict[str, dict],
node_data: EndNodeData) -> list[list[str]]:
"""
Extract stream variable selector from node data
:param node_id_config_mapping: node id config mapping
:param node_data: node data object
:return:
"""
variable_selectors = node_data.outputs
value_selectors = []
for variable_selector in variable_selectors:
if not variable_selector.value_selector:
continue
node_id = variable_selector.value_selector[0]
if node_id != 'sys' and node_id in node_id_config_mapping:
node = node_id_config_mapping[node_id]
node_type = node.get('data', {}).get('type')
if (
variable_selector.value_selector not in value_selectors
and node_type == NodeType.LLM.value
and variable_selector.value_selector[1] == 'text'
):
value_selectors.append(variable_selector.value_selector)
return value_selectors
@classmethod
def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \
-> list[list[str]]:
"""
Extract stream variable selector from node config
:param node_id_config_mapping: node id config mapping
:param config: node config
:return:
"""
node_data = EndNodeData(**config.get("data", {}))
return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data)
@classmethod
def _fetch_ends_dependencies(cls,
end_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict]
) -> dict[str, list[str]]:
"""
Fetch end dependencies
:param end_node_ids: end node ids
:param reverse_edge_mapping: reverse edge mapping
:param node_id_config_mapping: node id config mapping
:return:
"""
end_dependencies: dict[str, list[str]] = {}
for end_node_id in end_node_ids:
if end_dependencies.get(end_node_id) is None:
end_dependencies[end_node_id] = []
cls._recursive_fetch_end_dependencies(
current_node_id=end_node_id,
end_node_id=end_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
end_dependencies=end_dependencies
)
return end_dependencies
@classmethod
def _recursive_fetch_end_dependencies(cls,
current_node_id: str,
end_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]],
# type: ignore[name-defined]
end_dependencies: dict[str, list[str]]
) -> None:
"""
Recursive fetch end dependencies
:param current_node_id: current node id
:param end_node_id: end node id
:param node_id_config_mapping: node id config mapping
:param reverse_edge_mapping: reverse edge mapping
:param end_dependencies: end dependencies
:return:
"""
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
if source_node_type in (
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value,
):
end_dependencies[end_node_id].append(source_node_id)
else:
cls._recursive_fetch_end_dependencies(
current_node_id=source_node_id,
end_node_id=end_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
end_dependencies=end_dependencies
)
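
A toy illustration of the reverse-edge walk above (plain dicts standing in for `GraphEdge` objects, hypothetical node ids): from an end node, walk upstream and record the nearest branching ancestors it depends on.

```python
# node -> upstream source node ids
reverse_edges = {'end': ['llm'], 'llm': ['if'], 'if': ['start']}
node_types = {'start': 'start', 'if': 'if-else', 'llm': 'llm', 'end': 'end'}
BRANCH_TYPES = {'if-else', 'question-classifier'}

def fetch_deps(node: str, deps: list[str]) -> None:
    for src in reverse_edges.get(node, []):
        if node_types[src] in BRANCH_TYPES:
            deps.append(src)       # stop at the branching node
        else:
            fetch_deps(src, deps)  # keep walking upstream

deps: list[str] = []
fetch_deps('end', deps)
assert deps == ['if']
```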

View File

@@ -0,0 +1,191 @@
import logging
from collections.abc import Generator
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
logger = logging.getLogger(__name__)
class EndStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
super().__init__(graph, variable_pool)
self.end_stream_param = graph.end_stream_param
self.route_position = {}
for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
self.route_position[end_node_id] = 0
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
self.has_outputed = False
self.outputed_node_ids = set()
def process(self,
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
if isinstance(event, NodeRunStartedEvent):
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
self.reset()
yield event
elif isinstance(event, NodeRunStreamChunkEvent):
if event.in_iteration_id:
if self.has_outputed and event.node_id not in self.outputed_node_ids:
event.chunk_content = '\n' + event.chunk_content
self.outputed_node_ids.add(event.node_id)
self.has_outputed = True
yield event
continue
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
]
else:
stream_out_end_node_ids = self._get_stream_out_end_node_ids(event)
self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
] = stream_out_end_node_ids
if stream_out_end_node_ids:
if self.has_outputed and event.node_id not in self.outputed_node_ids:
event.chunk_content = '\n' + event.chunk_content
self.outputed_node_ids.add(event.node_id)
self.has_outputed = True
yield event
elif isinstance(event, NodeRunSucceededEvent):
yield event
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
# update self.route_position after all stream event finished
for end_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
self.route_position[end_node_id] += 1
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
# remove unreachable nodes
self._remove_unreachable_nodes(event)
# generate stream outputs
yield from self._generate_stream_outputs_when_node_finished(event)
else:
yield event
def reset(self) -> None:
self.route_position = {}
for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
self.route_position[end_node_id] = 0
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _generate_stream_outputs_when_node_finished(self,
event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
"""
Generate stream outputs.
:param event: node run succeeded event
:return:
"""
for end_node_id, position in self.route_position.items():
# all depends on end node id not in rest node ids
if (event.route_node_state.node_id != end_node_id
and (end_node_id not in self.rest_node_ids
or not all(dep_id not in self.rest_node_ids
for dep_id in self.end_stream_param.end_dependencies[end_node_id]))):
continue
route_position = self.route_position[end_node_id]
position = 0
value_selectors = []
for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
if position >= route_position:
value_selectors.append(current_value_selectors)
position += 1
for value_selector in value_selectors:
if not value_selector:
continue
value = self.variable_pool.get(
value_selector
)
if value is None:
break
text = value.markdown
if text:
current_node_id = value_selector[0]
if self.has_outputed and current_node_id not in self.outputed_node_ids:
text = '\n' + text
self.outputed_node_ids.add(current_node_id)
self.has_outputed = True
yield NodeRunStreamChunkEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=text,
from_variable_selector=value_selector,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
self.route_position[end_node_id] += 1
def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
"""
Get the end node ids that the stream chunk should be routed to
:param event: queue text chunk event
:return:
"""
if not event.from_variable_selector:
return []
stream_output_value_selector = event.from_variable_selector
if not stream_output_value_selector:
return []
stream_out_end_node_ids = []
for end_node_id, route_position in self.route_position.items():
if end_node_id not in self.rest_node_ids:
continue
# all depends on end node id not in rest node ids
if all(dep_id not in self.rest_node_ids
for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]):
continue
position = 0
value_selector = None
for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
if position == route_position:
value_selector = current_value_selectors
break
position += 1
if not value_selector:
continue
# check chunk node id is before current node id or equal to current node id
if value_selector != stream_output_value_selector:
continue
stream_out_end_node_ids.append(end_node_id)
return stream_out_end_node_ids
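
The `has_outputed` / `outputed_node_ids` bookkeeping above inserts a newline exactly once when a new node starts streaming after some text has already been emitted, so parallel branches don't run together. A self-contained toy of that separator rule (names mirror the processor's fields):

```python
class SeparatorState:
    """Toy mirror of has_outputed / outputed_node_ids above."""
    def __init__(self) -> None:
        self.has_outputed = False
        self.outputed_node_ids: set[str] = set()

    def decorate(self, node_id: str, chunk: str) -> str:
        if self.has_outputed and node_id not in self.outputed_node_ids:
            chunk = '\n' + chunk  # separate a new node's output once
        self.outputed_node_ids.add(node_id)
        self.has_outputed = True
        return chunk

s = SeparatorState()
assert s.decorate('llm_1', 'foo') == 'foo'    # first output, no separator
assert s.decorate('llm_2', 'bar') == '\nbar'  # new node after prior output
assert s.decorate('llm_2', 'baz') == 'baz'    # same node, no separator
```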

View File

@@ -1,3 +1,5 @@
from pydantic import BaseModel, Field
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
@@ -7,3 +9,17 @@ class EndNodeData(BaseNodeData):
END Node Data.
"""
outputs: list[VariableSelector]
class EndStreamParam(BaseModel):
"""
EndStreamParam entity
"""
end_dependencies: dict[str, list[str]] = Field(
...,
description="end dependencies (end node id -> dependent node ids)"
)
end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field(
...,
description="end stream variable selector mapping (end node id -> stream variable selectors)"
)

View File

@@ -0,0 +1,20 @@
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult
class RunCompletedEvent(BaseModel):
run_result: NodeRunResult = Field(..., description="run result")
class RunStreamChunkEvent(BaseModel):
chunk_content: str = Field(..., description="chunk content")
from_variable_selector: list[str] = Field(..., description="from variable selector")
class RunRetrieverResourceEvent(BaseModel):
retriever_resources: list[dict] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent
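
A minimal sketch of consuming the `RunEvent` union defined above; a node's `_run()` generator yields these and the runner branches on type:

```python
def handle(event: RunEvent) -> None:
    if isinstance(event, RunCompletedEvent):
        print("completed:", event.run_result.status)
    elif isinstance(event, RunStreamChunkEvent):
        print("chunk:", event.chunk_content)
    elif isinstance(event, RunRetrieverResourceEvent):
        print("context:", event.context)
```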

View File

@@ -1,15 +1,14 @@
import logging
from collections.abc import Mapping, Sequence
from mimetypes import guess_extension
from os import path
from typing import cast
from typing import Any, cast
from configs import dify_config
from core.app.segments import parser
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.http_request.entities import (
HttpRequestNodeData,
@@ -48,17 +47,22 @@ class HttpRequestNode(BaseNode):
},
}
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data)
# TODO: Switch to use segment directly
if node_data.authorization.config and node_data.authorization.config.api_key:
node_data.authorization.config.api_key = parser.convert_template(template=node_data.authorization.config.api_key, variable_pool=variable_pool).text
node_data.authorization.config.api_key = parser.convert_template(
template=node_data.authorization.config.api_key,
variable_pool=self.graph_runtime_state.variable_pool
).text
# init http executor
http_executor = None
try:
http_executor = HttpExecutor(
node_data=node_data, timeout=self._get_request_timeout(node_data), variable_pool=variable_pool
node_data=node_data,
timeout=self._get_request_timeout(node_data),
variable_pool=self.graph_runtime_state.variable_pool
)
# invoke http executor
@@ -102,13 +106,19 @@ class HttpRequestNode(BaseNode):
return timeout
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: HttpRequestNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
node_data = cast(HttpRequestNodeData, node_data)
try:
http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT)
@@ -116,7 +126,7 @@ class HttpRequestNode(BaseNode):
variable_mapping = {}
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector
return variable_mapping
except Exception as e:

View File

@@ -3,20 +3,7 @@ from typing import Literal, Optional
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
class Condition(BaseModel):
"""
Condition entity
"""
variable_selector: list[str]
comparison_operator: Literal[
# for string or array
"contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "regex match",
# for number
"=", "", ">", "<", "", "", "null", "not null"
]
value: Optional[str] = None
from core.workflow.utils.condition.entities import Condition
class IfElseNodeData(BaseNodeData):

View File

@@ -1,13 +1,10 @@
import re
from collections.abc import Sequence
from typing import Optional, cast
from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.if_else.entities import Condition, IfElseNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.utils.condition.processor import ConditionProcessor
from models.workflow import WorkflowNodeExecutionStatus
@@ -15,31 +12,35 @@ class IfElseNode(BaseNode):
_node_data_cls = IfElseNodeData
_node_type = NodeType.IF_ELSE
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:return:
"""
node_data = self.node_data
node_data = cast(IfElseNodeData, node_data)
node_inputs = {
node_inputs: dict[str, list] = {
"conditions": []
}
process_datas = {
process_datas: dict[str, list] = {
"condition_results": []
}
input_conditions = []
final_result = False
selected_case_id = None
condition_processor = ConditionProcessor()
try:
# Check if the new cases structure is used
if node_data.cases:
for case in node_data.cases:
input_conditions, group_result = self.process_conditions(variable_pool, case.conditions)
input_conditions, group_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=case.conditions
)
# Apply the logical operator for the current case
final_result = all(group_result) if case.logical_operator == "and" else any(group_result)
@@ -58,7 +59,10 @@ class IfElseNode(BaseNode):
else:
# Fallback to old structure if cases are not defined
input_conditions, group_result = self.process_conditions(variable_pool, node_data.conditions)
input_conditions, group_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=node_data.conditions
)
final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result)
@@ -94,376 +98,17 @@ class IfElseNode(BaseNode):
return data
def evaluate_condition(
self, actual_value: Optional[str | list], expected_value: str, comparison_operator: str
) -> bool:
"""
Evaluate condition
:param actual_value: actual value
:param expected_value: expected value
:param comparison_operator: comparison operator
:return: bool
"""
if comparison_operator == "contains":
return self._assert_contains(actual_value, expected_value)
elif comparison_operator == "not contains":
return self._assert_not_contains(actual_value, expected_value)
elif comparison_operator == "start with":
return self._assert_start_with(actual_value, expected_value)
elif comparison_operator == "end with":
return self._assert_end_with(actual_value, expected_value)
elif comparison_operator == "is":
return self._assert_is(actual_value, expected_value)
elif comparison_operator == "is not":
return self._assert_is_not(actual_value, expected_value)
elif comparison_operator == "empty":
return self._assert_empty(actual_value)
elif comparison_operator == "not empty":
return self._assert_not_empty(actual_value)
elif comparison_operator == "=":
return self._assert_equal(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_not_equal(actual_value, expected_value)
elif comparison_operator == ">":
return self._assert_greater_than(actual_value, expected_value)
elif comparison_operator == "<":
return self._assert_less_than(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_greater_than_or_equal(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_less_than_or_equal(actual_value, expected_value)
elif comparison_operator == "null":
return self._assert_null(actual_value)
elif comparison_operator == "not null":
return self._assert_not_null(actual_value)
elif comparison_operator == "regex match":
return self._assert_regex_match(actual_value, expected_value)
else:
raise ValueError(f"Invalid comparison operator: {comparison_operator}")
def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]):
input_conditions = []
group_result = []
for condition in conditions:
actual_variable = variable_pool.get_any(condition.variable_selector)
if condition.value is not None:
variable_template_parser = VariableTemplateParser(template=condition.value)
expected_value = variable_template_parser.extract_variable_selectors()
variable_selectors = variable_template_parser.extract_variable_selectors()
if variable_selectors:
for variable_selector in variable_selectors:
value = variable_pool.get_any(variable_selector.value_selector)
expected_value = variable_template_parser.format({variable_selector.variable: value})
else:
expected_value = condition.value
else:
expected_value = None
comparison_operator = condition.comparison_operator
input_conditions.append(
{
"actual_value": actual_variable,
"expected_value": expected_value,
"comparison_operator": comparison_operator
}
)
result = self.evaluate_condition(actual_variable, expected_value, comparison_operator)
group_result.append(result)
return input_conditions, group_result
def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
"""
Assert contains
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str | list):
raise ValueError('Invalid actual value type: string or array')
if expected_value not in actual_value:
return False
return True
def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
"""
Assert not contains
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return True
if not isinstance(actual_value, str | list):
raise ValueError('Invalid actual value type: string or array')
if expected_value in actual_value:
return False
return True
def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert start with
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if not actual_value.startswith(expected_value):
return False
return True
def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert end with
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if not actual_value.endswith(expected_value):
return False
return True
def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert is
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if actual_value != expected_value:
return False
return True
def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert is not
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if actual_value == expected_value:
return False
return True
def _assert_empty(self, actual_value: Optional[str]) -> bool:
"""
Assert empty
:param actual_value: actual value
:return:
"""
if not actual_value:
return True
return False
def _assert_regex_match(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert regex match
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
return re.search(expected_value, actual_value) is not None
def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
"""
Assert not empty
:param actual_value: actual value
:return:
"""
if actual_value:
return True
return False
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value != expected_value:
return False
return True
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert not equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value == expected_value:
return False
return True
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert greater than
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value <= expected_value:
return False
return True
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert less than
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value >= expected_value:
return False
return True
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert greater than or equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value < expected_value:
return False
return True
def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert less than or equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value > expected_value:
return False
return True
def _assert_null(self, actual_value: Optional[int | float]) -> bool:
"""
Assert null
:param actual_value: actual value
:return:
"""
if actual_value is None:
return True
return False
def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
"""
Assert not null
:param actual_value: actual value
:return:
"""
if actual_value is not None:
return True
return False
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IfElseNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""

View File

@@ -1,6 +1,6 @@
from typing import Any, Optional
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
class IterationNodeData(BaseIterationNodeData):
@@ -11,6 +11,13 @@ class IterationNodeData(BaseIterationNodeData):
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
class IterationStartNodeData(BaseNodeData):
"""
Iteration Start Node Data.
"""
pass
class IterationState(BaseIterationState):
"""
Iteration State.

View File

@@ -1,124 +1,371 @@
from typing import cast
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime, timezone
from typing import Any, cast
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseIterationState
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseIterationNode
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
BaseParallelBranchEvent,
GraphRunFailedEvent,
InNodeEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.utils.condition.entities import Condition
from models.workflow import WorkflowNodeExecutionStatus
logger = logging.getLogger(__name__)
class IterationNode(BaseIterationNode):
class IterationNode(BaseNode):
"""
Iteration Node.
"""
_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
"""
Run the node.
"""
self.node_data = cast(IterationNodeData, self.node_data)
iterator = variable_pool.get_any(self.node_data.iterator_selector)
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
if not isinstance(iterator, list):
raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.")
state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={
'iterator_selector': iterator
}, outputs=[], metadata=IterationState.MetaData(
iterator_length=len(iterator) if iterator is not None else 0
))
if not iterator_list_segment:
raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found")
self._set_current_iteration_variable(variable_pool, state)
return state
iterator_list_value = iterator_list_segment.to_object()
def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str:
"""
Get next iteration start node id based on the graph.
:param variable_pool: variable pool
:param state: iteration state
:return: next node id
"""
# resolve current output
self._resolve_current_output(variable_pool, state)
# move to next iteration
self._next_iteration(variable_pool, state)
if not isinstance(iterator_list_value, list):
raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
node_data = cast(IterationNodeData, self.node_data)
if self._reached_iteration_limit(variable_pool, state):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs = {
"iterator_selector": iterator_list_value
}
graph_config = self.graph_config
if not self.node_data.start_node_id:
raise ValueError(f'field start_node_id in iteration {self.node_id} not found')
root_node_id = self.node_data.start_node_id
# init graph
iteration_graph = Graph.init(
graph_config=graph_config,
root_node_id=root_node_id
)
if not iteration_graph:
raise ValueError('iteration graph not found')
leaf_node_ids = iteration_graph.get_leaf_node_ids()
iteration_leaf_node_ids = []
for leaf_node_id in leaf_node_ids:
node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id)
if not node_config:
continue
leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id")
if not leaf_node_iteration_id:
continue
if leaf_node_iteration_id != self.node_id:
continue
iteration_leaf_node_ids.append(leaf_node_id)
# add a conditional edge from each leaf node back to the root node
iteration_graph.add_extra_edge(
source_node_id=leaf_node_id,
target_node_id=root_node_id,
run_condition=RunCondition(
type="condition",
conditions=[
Condition(
variable_selector=[self.node_id, "index"],
comparison_operator="<",
value=str(len(iterator_list_value))
)
]
)
)
variable_pool = self.graph_runtime_state.variable_pool
# append iteration variable (item, index) to variable pool
variable_pool.add(
[self.node_id, 'index'],
0
)
variable_pool.add(
[self.node_id, 'item'],
iterator_list_value[0]
)
# init graph engine
from core.workflow.graph_engine.graph_engine import GraphEngine
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_type=self.workflow_type,
workflow_id=self.workflow_id,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
graph=iteration_graph,
graph_config=graph_config,
variable_pool=variable_pool,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
)
start_at = datetime.now(timezone.utc).replace(tzinfo=None)
yield IterationRunStartedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
metadata={
"iterator_length": len(iterator_list_value)
},
predecessor_node_id=self.previous_node_id
)
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=0,
pre_iteration_output=None
)
outputs: list[Any] = []
try:
# run workflow
rst = graph_engine.run()
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START:
continue
if isinstance(event, NodeRunSucceededEvent):
if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
event.route_node_state.node_run_result.metadata = metadata
yield event
# handle iteration run result
if event.route_node_state.node_id in iteration_leaf_node_ids:
# append to iteration output variable list
current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
outputs.append(current_iteration_output)
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove_node(node_id)
# move to next iteration
current_index = variable_pool.get([self.node_id, 'index'])
if current_index is None:
raise ValueError(f'iteration {self.node_id} current index not found')
next_index = int(current_index.to_object()) + 1
variable_pool.add(
[self.node_id, 'index'],
next_index
)
if next_index < len(iterator_list_value):
variable_pool.add(
[self.node_id, 'item'],
iterator_list_value[next_index]
)
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
pre_iteration_output=jsonable_encoder(
current_iteration_output) if current_iteration_output else None
)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
break
else:
event = cast(InNodeEvent, event)
yield event
yield IterationRunSucceededEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={
'output': jsonable_encoder(state.outputs)
"output": jsonable_encoder(outputs)
},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
}
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
'output': jsonable_encoder(outputs)
}
)
)
except Exception as e:
# iteration run failed
logger.exception("Iteration run failed")
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
},
error=str(e),
)
return node_data.start_node_id
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
)
finally:
# remove iteration variable (item, index) from variable pool after iteration run completed
variable_pool.remove([self.node_id, 'index'])
variable_pool.remove([self.node_id, 'item'])
def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState):
"""
Set current iteration variable.
:param variable_pool: variable pool
:param state: iteration state
"""
node_data = cast(IterationNodeData, self.node_data)
variable_pool.add((self.node_id, 'index'), state.index)
# get the iterator value
iterator = variable_pool.get_any(node_data.iterator_selector)
if iterator is None or not isinstance(iterator, list):
return
if state.index < len(iterator):
variable_pool.add((self.node_id, 'item'), iterator[state.index])
def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
"""
Move to next iteration.
:param variable_pool: variable pool
"""
state.index += 1
self._set_current_iteration_variable(variable_pool, state)
def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState):
"""
Check if iteration limit is reached.
:return: True if iteration limit is reached, False otherwise
"""
node_data = cast(IterationNodeData, self.node_data)
iterator = variable_pool.get_any(node_data.iterator_selector)
if iterator is None or not isinstance(iterator, list):
return True
return state.index >= len(iterator)
def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState):
"""
Resolve current output.
:param variable_pool: variable pool
"""
output_selector = cast(IterationNodeData, self.node_data).output_selector
output = variable_pool.get_any(output_selector)
# clear the output for this iteration
variable_pool.remove([self.node_id] + output_selector[1:])
state.current_output = output
if output is not None:
# NOTE: This is a temporary patch to process double nested list (for example, DALL-E output in iteration).
if isinstance(output, list):
state.outputs.extend(output)
else:
state.outputs.append(output)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IterationNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {
'input_selector': node_data.iterator_selector,
}
variable_mapping = {
f'{node_id}.input_selector': node_data.iterator_selector,
}
# init graph
iteration_graph = Graph.init(
graph_config=graph_config,
root_node_id=node_data.start_node_id
)
if not iteration_graph:
raise ValueError('iteration graph not found')
for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items():
if sub_node_config.get('data', {}).get('iteration_id') != node_id:
continue
# variable selector to variable mapping
try:
# Get node class
from core.workflow.nodes.node_mapping import node_classes
node_type = NodeType.value_of(sub_node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
if not node_cls:
continue
node_cls = cast(BaseNode, node_cls)
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config,
config=sub_node_config
)
sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
except NotImplementedError:
sub_node_variable_mapping = {}
# remove iteration variables
sub_node_variable_mapping = {
sub_node_id + '.' + key: value for key, value in sub_node_variable_mapping.items()
if value[0] != node_id
}
variable_mapping.update(sub_node_variable_mapping)
# drop mappings that resolve to variables produced inside the iteration subgraph
variable_mapping = {
key: value for key, value in variable_mapping.items()
if value[0] not in iteration_graph.node_ids
}
return variable_mapping
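(Illustrative aside, not part of this diff.) The rewrite above turns iteration into a plain subgraph: each leaf node gets a conditional back-edge (a RunCondition comparing [node_id, 'index'] with "<" against len(items)) to the start node, and the engine keeps re-entering the subgraph while that condition holds, bumping index and item in the variable pool each pass. A standalone sketch of the same control flow, with run_body standing in for one pass of the subgraph:

from typing import Any, Callable

def run_iteration(items: list[Any], run_body: Callable[[Any], Any]) -> list[Any]:
    if not items:
        return []
    # (item, index) live in the pool for the duration of the loop
    variable_pool: dict[str, Any] = {'index': 0, 'item': items[0]}
    outputs: list[Any] = []
    while True:
        # run one pass of the iteration subgraph
        outputs.append(run_body(variable_pool['item']))
        # move to the next iteration, mirroring the engine's index bump
        next_index = variable_pool['index'] + 1
        variable_pool['index'] = next_index
        # back-edge run condition: index < len(items)
        if next_index >= len(items):
            break
        variable_pool['item'] = items[next_index]
    return outputs

print(run_iteration([1, 2, 3], lambda x: x * 2))  # [2, 4, 6]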

View File

@@ -0,0 +1,39 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
from models.workflow import WorkflowNodeExecutionStatus
class IterationStartNode(BaseNode):
"""
Iteration Start Node.
"""
_node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START
def _run(self) -> NodeRunResult:
"""
Run the node.
"""
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IterationNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {}

View File

@@ -1,3 +1,5 @@
import logging
from collections.abc import Mapping, Sequence
from typing import Any, cast
from sqlalchemy import func
@@ -12,15 +14,15 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.workflow import WorkflowNodeExecutionStatus
logger = logging.getLogger(__name__)
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
@@ -37,11 +39,11 @@ class KnowledgeRetrievalNode(BaseNode):
_node_data_cls = KnowledgeRetrievalNodeData
node_type = NodeType.KNOWLEDGE_RETRIEVAL
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)
def _run(self) -> NodeRunResult:
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
# extract variables
variable = variable_pool.get_any(node_data.query_variable_selector)
variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector)
query = variable
variables = {
'query': query
@@ -68,7 +70,7 @@ class KnowledgeRetrievalNode(BaseNode):
)
except Exception as e:
logger.exception("Error when running knowledge retrieval node")
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
@@ -235,11 +237,21 @@ class KnowledgeRetrievalNode(BaseNode):
return context_list
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
node_data = node_data
node_data = cast(cls._node_data_cls, node_data)
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: KnowledgeRetrievalNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
variable_mapping = {}
variable_mapping['query'] = node_data.query_variable_selector
variable_mapping[node_id + '.query'] = node_data.query_variable_selector
return variable_mapping
def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[
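(Illustrative aside, not part of this diff.) The mapping key changes from 'query' to node_id + '.query': selectors are now namespaced per node, so two nodes that both expose a 'query' variable no longer collide when branches run in parallel. The same pattern recurs in the other nodes below; a tiny sketch:

def prefix_mapping(node_id: str, mapping: dict[str, list[str]]) -> dict[str, list[str]]:
    # namespace every key as '<node_id>.<variable>'
    return {f'{node_id}.{key}': value for key, value in mapping.items()}

print(prefix_mapping('kr_node_1', {'query': ['sys', 'query']}))
# {'kr_node_1.query': ['sys', 'query']}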

View File

@@ -1,16 +1,17 @@
import json
from collections.abc import Generator
from collections.abc import Generator, Mapping, Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from pydantic import BaseModel
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
@@ -25,7 +26,9 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
@@ -43,17 +46,26 @@ if TYPE_CHECKING:
class ModelInvokeCompleted(BaseModel):
"""
Model invoke completed
"""
text: str
usage: LLMUsage
finish_reason: Optional[str] = None
class LLMNode(BaseNode):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
"""
Run node
:param variable_pool: variable pool
:return:
"""
node_data = cast(LLMNodeData, deepcopy(self.node_data))
variable_pool = self.graph_runtime_state.variable_pool
node_inputs = None
process_data = None
@@ -80,10 +92,15 @@ class LLMNode(BaseNode):
node_inputs['#files#'] = [file.to_dict() for file in files]
# fetch context value
context = self._fetch_context(node_data, variable_pool)
generator = self._fetch_context(node_data, variable_pool)
context = None
for event in generator:
if isinstance(event, RunRetrieverResourceEvent):
context = event.context
yield event
if context:
node_inputs['#context#'] = context
node_inputs['#context#'] = context # type: ignore
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model)
@@ -115,19 +132,34 @@ class LLMNode(BaseNode):
}
# handle invoke result
result_text, usage, finish_reason = self._invoke_llm(
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop
)
result_text = ''
usage = LLMUsage.empty_usage()
finish_reason = None
for event in generator:
if isinstance(event, RunStreamChunkEvent):
yield event
elif isinstance(event, ModelInvokeCompleted):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
break
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=node_inputs,
process_data=process_data
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=node_inputs,
process_data=process_data
)
)
return
outputs = {
'text': result_text,
@@ -135,22 +167,26 @@ class LLMNode(BaseNode):
'finish_reason': finish_reason
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=node_inputs,
process_data=process_data,
outputs=outputs,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
}
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=node_inputs,
process_data=process_data,
outputs=outputs,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
},
llm_usage=usage
)
)
def _invoke_llm(self, node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: list[str]) -> tuple[str, LLMUsage]:
stop: Optional[list[str]] = None) \
-> Generator[RunEvent | ModelInvokeCompleted, None, None]:
"""
Invoke large language model
:param node_data_model: node data model
@@ -170,23 +206,31 @@ class LLMNode(BaseNode):
)
# handle invoke result
text, usage, finish_reason = self._handle_invoke_result(
generator = self._handle_invoke_result(
invoke_result=invoke_result
)
usage = LLMUsage.empty_usage()
for event in generator:
yield event
if isinstance(event, ModelInvokeCompleted):
usage = event.usage
# deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
return text, usage, finish_reason
def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
-> Generator[RunEvent | ModelInvokeCompleted, None, None]:
"""
Handle invoke result
:param invoke_result: invoke result
:return:
"""
if isinstance(invoke_result, LLMResult):
return
model = None
prompt_messages = []
prompt_messages: list[PromptMessage] = []
full_text = ''
usage = None
finish_reason = None
@@ -194,7 +238,10 @@ class LLMNode(BaseNode):
text = result.delta.message.content
full_text += text
self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
yield RunStreamChunkEvent(
chunk_content=text,
from_variable_selector=[self.node_id, 'text']
)
if not model:
model = result.model
@@ -211,11 +258,15 @@ class LLMNode(BaseNode):
if not usage:
usage = LLMUsage.empty_usage()
return full_text, usage, finish_reason
yield ModelInvokeCompleted(
text=full_text,
usage=usage,
finish_reason=finish_reason
)
def _transform_chat_messages(self,
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
"""
Transform chat messages
@@ -224,13 +275,13 @@ class LLMNode(BaseNode):
"""
if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
if messages.edition_type == 'jinja2':
if messages.edition_type == 'jinja2' and messages.jinja2_text:
messages.text = messages.jinja2_text
return messages
for message in messages:
if message.edition_type == 'jinja2':
if message.edition_type == 'jinja2' and message.jinja2_text:
message.text = message.jinja2_text
return messages
@@ -348,7 +399,7 @@ class LLMNode(BaseNode):
return files
def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]:
def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]:
"""
Fetch context
:param node_data: node data
@@ -356,15 +407,18 @@ class LLMNode(BaseNode):
:return:
"""
if not node_data.context.enabled:
return None
return
if not node_data.context.variable_selector:
return None
return
context_value = variable_pool.get_any(node_data.context.variable_selector)
if context_value:
if isinstance(context_value, str):
return context_value
yield RunRetrieverResourceEvent(
retriever_resources=[],
context=context_value
)
elif isinstance(context_value, list):
context_str = ''
original_retriever_resource = []
@@ -381,17 +435,10 @@ class LLMNode(BaseNode):
if retriever_resource:
original_retriever_resource.append(retriever_resource)
if self.callbacks and original_retriever_resource:
for callback in self.callbacks:
callback.on_event(
event=QueueRetrieverResourcesEvent(
retriever_resources=original_retriever_resource
)
)
return context_str.strip()
return None
yield RunRetrieverResourceEvent(
retriever_resources=original_retriever_resource,
context=context_str.strip()
)
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
"""
@@ -574,7 +621,8 @@ class LLMNode(BaseNode):
if not isinstance(prompt_message.content, str):
prompt_message_content = []
for content_item in prompt_message.content:
if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(content_item, ImagePromptMessageContent):
if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(
content_item, ImagePromptMessageContent):
# Override vision config if LLM node has vision config
if vision_detail:
content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)
@@ -646,13 +694,19 @@ class LLMNode(BaseNode):
db.session.commit()
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: LLMNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
prompt_template = node_data.prompt_template
variable_selectors = []
@@ -702,6 +756,10 @@ class LLMNode(BaseNode):
for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {
node_id + '.' + key: value for key, value in variable_mapping.items()
}
return variable_mapping
@classmethod
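(Illustrative aside, not part of this diff.) _run, _invoke_llm and _handle_invoke_result now form a generator pipeline: stream-chunk events are yielded downstream as tokens arrive, and one final completion object carries the accumulated text, usage and finish_reason. A self-contained sketch of that protocol; StreamChunk and InvokeCompleted below are hypothetical stand-ins for RunStreamChunkEvent and ModelInvokeCompleted:

from collections.abc import Generator
from dataclasses import dataclass

@dataclass
class StreamChunk:
    chunk_content: str

@dataclass
class InvokeCompleted:
    text: str
    finish_reason: str | None = None

def handle_invoke_result(tokens: list[str]) -> Generator[StreamChunk | InvokeCompleted, None, None]:
    full_text = ''
    for token in tokens:
        full_text += token
        yield StreamChunk(chunk_content=token)  # forwarded downstream immediately
    yield InvokeCompleted(text=full_text, finish_reason='stop')

result_text = ''
for event in handle_invoke_result(['Hel', 'lo']):
    if isinstance(event, StreamChunk):
        pass  # a caller would push the chunk to the client here
    elif isinstance(event, InvokeCompleted):
        result_text = event.text
        break  # mirroring the consumer loop in _run above
print(result_text)  # Hello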

View File

@@ -1,20 +1,34 @@
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseIterationNode
from typing import Any
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
from core.workflow.utils.condition.entities import Condition
class LoopNode(BaseIterationNode):
class LoopNode(BaseNode):
"""
Loop Node.
"""
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP
def _run(self, variable_pool: VariablePool) -> LoopState:
return super()._run(variable_pool)
def _run(self) -> LoopState:
return super()._run()
def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str:
@classmethod
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
"""
Get next iteration start node id based on the graph.
Get conditions.
"""
node_id = node_config.get('id')
if not node_id:
return []
# TODO waiting for implementation
return [Condition(
variable_selector=[node_id, 'index'],
comparison_operator="",
value_type="value_selector",
value_selector=[]
)]

View File

@@ -0,0 +1,37 @@
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
from core.workflow.nodes.variable_assigner import VariableAssignerNode
node_classes = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.ANSWER: AnswerNode,
NodeType.LLM: LLMNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.IF_ELSE: IfElseNode,
NodeType.CODE: CodeNode,
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.TOOL: ToolNode,
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode,
NodeType.ITERATION_START: IterationStartNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
}
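(Illustrative aside, not part of this diff.) node_classes is a plain type-to-class registry, so the engine (and the iteration node's mapping extraction above) can resolve a node implementation from its declared type string without if/elif chains. A toy equivalent:

from enum import Enum

class NodeType(Enum):  # stand-in for core.workflow.entities.node_entities.NodeType
    START = 'start'
    END = 'end'

class StartNode: ...
class EndNode: ...

node_classes = {NodeType.START: StartNode, NodeType.END: EndNode}

node_cls = node_classes.get(NodeType('start'))  # look up by the declared type string
print(node_cls)  # <class '__main__.StartNode'>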

View File

@@ -1,6 +1,7 @@
import json
import uuid
from typing import Optional, cast
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -66,12 +67,12 @@ class ParameterExtractorNode(LLMNode):
}
}
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run the node.
"""
node_data = cast(ParameterExtractorNodeData, self.node_data)
variable = variable_pool.get_any(node_data.query)
variable = self.graph_runtime_state.variable_pool.get_any(node_data.query)
if not variable:
raise ValueError("Input variable content not found or is empty")
query = variable
@@ -92,17 +93,20 @@ class ParameterExtractorNode(LLMNode):
raise ValueError("Model schema not found")
# fetch memory
memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance)
if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \
and node_data.reasoning_mode == 'function_call':
# use function call
prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
node_data, query, variable_pool, model_config, memory
node_data, query, self.graph_runtime_state.variable_pool, model_config, memory
)
else:
# use prompt engineering
prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config,
prompt_messages = self._generate_prompt_engineering_prompt(node_data,
query,
self.graph_runtime_state.variable_pool,
model_config,
memory)
prompt_message_tools = []
@@ -172,7 +176,8 @@ class ParameterExtractorNode(LLMNode):
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
}
},
llm_usage=usage
)
def _invoke_llm(self, node_data_model: ModelConfig,
@@ -697,15 +702,19 @@ class ParameterExtractorNode(LLMNode):
return self._model_instance, self._model_config
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[
str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ParameterExtractorNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
node_data = node_data
variable_mapping = {
'query': node_data.query
}
@@ -715,4 +724,8 @@ class ParameterExtractorNode(LLMNode):
for selector in variable_template_parser.extract_variable_selectors():
variable_mapping[selector.variable] = selector.value_selector
variable_mapping = {
node_id + '.' + key: value for key, value in variable_mapping.items()
}
return variable_mapping
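(Illustrative aside, not part of this diff.) The function-call branch above is taken only when the model's feature set intersects the tool-calling features and the node is configured for function_call reasoning; otherwise it falls back to prompt engineering. The same test with plain strings standing in for ModelFeature members:

model_features = {'tool-call', 'vision'}
TOOL_FEATURES = {'tool-call', 'multi-tool-call'}  # stand-ins for ModelFeature values

reasoning_mode = 'function_call'
if (model_features & TOOL_FEATURES) and reasoning_mode == 'function_call':
    strategy = 'generate_function_call_prompt'   # structured tool-call extraction
else:
    strategy = 'prompt_engineering'              # plain-prompt extraction
print(strategy)  # generate_function_call_prompt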

View File

@@ -1,10 +1,12 @@
import json
import logging
from typing import Optional, Union, cast
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -13,10 +15,9 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.question_classifier.template_prompts import (
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
@@ -36,9 +37,10 @@ class QuestionClassifierNode(LLMNode):
_node_data_cls = QuestionClassifierNodeData
node_type = NodeType.QUESTION_CLASSIFIER
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
node_data = cast(QuestionClassifierNodeData, self.node_data)
variable_pool = self.graph_runtime_state.variable_pool
# extract variables
variable = variable_pool.get(node_data.query_variable_selector)
@@ -63,12 +65,23 @@ class QuestionClassifierNode(LLMNode):
)
# handle invoke result
result_text, usage, finish_reason = self._invoke_llm(
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop
)
result_text = ''
usage = LLMUsage.empty_usage()
finish_reason = None
for event in generator:
if isinstance(event, ModelInvokeCompleted):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
break
category_name = node_data.classes[0].name
category_id = node_data.classes[0].id
try:
@@ -109,7 +122,8 @@ class QuestionClassifierNode(LLMNode):
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
}
},
llm_usage=usage
)
except ValueError as e:
@@ -121,13 +135,24 @@ class QuestionClassifierNode(LLMNode):
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
}
},
llm_usage=usage
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
node_data = node_data
node_data = cast(cls._node_data_cls, node_data)
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: QuestionClassifierNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
variable_mapping = {'query': node_data.query_variable_selector}
variable_selectors = []
if node_data.instruction:
@@ -135,6 +160,11 @@ class QuestionClassifierNode(LLMNode):
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {
node_id + '.' + key: value for key, value in variable_mapping.items()
}
return variable_mapping
@classmethod

View File

@@ -1,7 +1,9 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus
@@ -11,14 +13,13 @@ class StartNode(BaseNode):
_node_data_cls = StartNodeData
_node_type = NodeType.START
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:return:
"""
node_inputs = dict(variable_pool.user_inputs)
system_inputs = variable_pool.system_variables
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]
@@ -30,9 +31,16 @@ class StartNode(BaseNode):
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: StartNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""

View File

@@ -1,15 +1,16 @@
import os
from typing import Optional, cast
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from models.workflow import WorkflowNodeExecutionStatus
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000'))
class TemplateTransformNode(BaseNode):
_node_data_cls = TemplateTransformNodeData
_node_type = NodeType.TEMPLATE_TRANSFORM
@@ -34,7 +35,7 @@ class TemplateTransformNode(BaseNode):
}
}
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run node
"""
@@ -45,7 +46,7 @@ class TemplateTransformNode(BaseNode):
variables = {}
for variable_selector in node_data.variables:
variable_name = variable_selector.variable
value = variable_pool.get_any(variable_selector.value_selector)
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
variables[variable_name] = value
# Run code
try:
@@ -60,7 +61,7 @@ class TemplateTransformNode(BaseNode):
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e)
)
if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
return NodeRunResult(
inputs=variables,
@@ -75,14 +76,21 @@ class TemplateTransformNode(BaseNode):
'output': result['result']
}
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: TemplateTransformNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {
variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
}
node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
}
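(Illustrative aside, not part of this diff.) The template-transform contract is unchanged: resolve each variable selector from the pool, render the template, and fail the node when the output exceeds TEMPLATE_TRANSFORM_MAX_LENGTH. A sketch with str.format standing in for the sandboxed code executor:

import os

MAX_LEN = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000'))

def render(template: str, variables: dict[str, object]) -> str:
    result = template.format(**variables)  # stand-in for CodeExecutor's Jinja2 sandbox
    if len(result) > MAX_LEN:
        raise ValueError(f'Output length {len(result)} exceeds {MAX_LEN}')
    return result

print(render('Hello, {name}!', {'name': 'Dify'}))  # Hello, Dify!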

View File

@@ -26,7 +26,7 @@ class ToolNode(BaseNode):
_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run the tool node
"""
@@ -56,8 +56,8 @@ class ToolNode(BaseNode):
# get parameters
tool_parameters = tool_runtime.get_runtime_parameters() or []
parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data)
parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data, for_log=True)
parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data)
parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True)
try:
messages = ToolEngine.workflow_invoke(
@@ -66,6 +66,7 @@ class ToolNode(BaseNode):
user_id=self.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
thread_pool_id=self.thread_pool_id,
)
except Exception as e:
return NodeRunResult(
@@ -145,7 +146,8 @@ class ToolNode(BaseNode):
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\
-> tuple[str, list[FileVar], list[dict]]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
@@ -221,9 +223,16 @@ class ToolNode(BaseNode):
return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ToolNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
@@ -239,4 +248,8 @@ class ToolNode(BaseNode):
elif input.type == 'constant':
pass
result = {
node_id + '.' + key: value for key, value in result.items()
}
return result

View File

@@ -1,8 +1,7 @@
from typing import cast
from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
from models.workflow import WorkflowNodeExecutionStatus
@@ -12,7 +11,7 @@ class VariableAggregatorNode(BaseNode):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_AGGREGATOR
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
node_data = cast(VariableAssignerNodeData, self.node_data)
# Get variables
outputs = {}
@@ -20,7 +19,7 @@ class VariableAggregatorNode(BaseNode):
if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled:
for selector in node_data.variables:
variable = variable_pool.get_any(selector)
variable = self.graph_runtime_state.variable_pool.get_any(selector)
if variable is not None:
outputs = {
"output": variable
@@ -33,7 +32,7 @@ class VariableAggregatorNode(BaseNode):
else:
for group in node_data.advanced_settings.groups:
for selector in group.variables:
variable = variable_pool.get_any(selector)
variable = self.graph_runtime_state.variable_pool.get_any(selector)
if variable is not None:
outputs[group.group_name] = {
@@ -49,5 +48,17 @@ class VariableAggregatorNode(BaseNode):
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: VariableAssignerNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {}
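(Illustrative aside, not part of this diff.) Without grouping, the aggregator returns the first selector that resolves to a non-None value, which is how exactly one of several parallel branches ends up feeding downstream nodes. A minimal sketch with a dict standing in for the variable pool:

from typing import Any, Optional

def aggregate(pool: dict[tuple[str, ...], Any], selectors: list[tuple[str, ...]]) -> Optional[dict]:
    for selector in selectors:
        variable = pool.get(selector)
        if variable is not None:
            return {'output': variable}  # first branch that produced output wins
    return None

pool = {('llm_2', 'text'): 'hi from branch 2'}
print(aggregate(pool, [('llm_1', 'text'), ('llm_2', 'text')]))
# {'output': 'hi from branch 2'}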

View File

@@ -6,7 +6,6 @@ from sqlalchemy.orm import Session
from core.app.segments import SegmentType, Variable, factory
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from extensions.ext_database import db
from models import ConversationVariable, WorkflowNodeExecutionStatus
@@ -19,23 +18,23 @@ class VariableAssignerNode(BaseNode):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
data = cast(VariableAssignerData, self.node_data)
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = variable_pool.get(data.assigned_variable_selector)
original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError('assigned variable not found')
match data.write_mode:
case WriteMode.OVER_WRITE:
income_value = variable_pool.get(data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_variable = original_variable.model_copy(update={'value': income_value.value})
case WriteMode.APPEND:
income_value = variable_pool.get(data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_value = original_variable.value + [income_value.value]
@@ -49,11 +48,11 @@ class VariableAssignerNode(BaseNode):
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
# Over write the variable.
variable_pool.add(data.assigned_variable_selector, updated_variable)
self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable)
# TODO: Move database operation to the pipeline.
# Update conversation variable.
conversation_id = variable_pool.get(['sys', 'conversation_id'])
conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id'])
if not conversation_id:
raise VariableAssignerNodeError('conversation_id not found')
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
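(Illustrative aside, not part of this diff.) The assigner's write modes reduce to a small match: OVER_WRITE replaces the stored value, APPEND extends a list value, and anything else raises, as in the unsupported-write-mode branch above. A standalone sketch (the real node also persists the result as a conversation variable):

from enum import Enum
from typing import Any

class WriteMode(Enum):  # stand-in for the node's WriteMode enum
    OVER_WRITE = 'over-write'
    APPEND = 'append'

def assign(original: list[Any], incoming: Any, mode: WriteMode) -> Any:
    match mode:
        case WriteMode.OVER_WRITE:
            return incoming                      # replace the stored value
        case WriteMode.APPEND:
            return list(original) + [incoming]   # extend a list-typed value
        case _:
            raise ValueError(f'unsupported write mode: {mode}')

print(assign(['a'], 'b', WriteMode.APPEND))  # ['a', 'b']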