feat: Parallel Execution of Nodes in Workflows (#8192)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
takatost
2024-09-10 15:23:16 +08:00
committed by GitHub
parent 5da0182800
commit dabfd74622
156 changed files with 11158 additions and 5605 deletions

View File

@@ -26,7 +26,7 @@ class ToolNode(BaseNode):
_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run the tool node
"""
@@ -56,8 +56,8 @@ class ToolNode(BaseNode):
# get parameters
tool_parameters = tool_runtime.get_runtime_parameters() or []
parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data)
parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data, for_log=True)
parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data)
parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True)
try:
messages = ToolEngine.workflow_invoke(
@@ -66,6 +66,7 @@ class ToolNode(BaseNode):
user_id=self.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
thread_pool_id=self.thread_pool_id,
)
except Exception as e:
return NodeRunResult(
@@ -145,7 +146,8 @@ class ToolNode(BaseNode):
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\
-> tuple[str, list[FileVar], list[dict]]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
@@ -221,9 +223,16 @@ class ToolNode(BaseNode):
return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ToolNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
@@ -239,4 +248,8 @@ class ToolNode(BaseNode):
elif input.type == 'constant':
pass
result = {
node_id + '.' + key: value for key, value in result.items()
}
return result