feat: Parallel Execution of Nodes in Workflows (#8192)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
takatost authored on 2024-09-10 15:23:16 +08:00, committed by GitHub
parent 5da0182800, commit dabfd74622
156 changed files with 11158 additions and 5605 deletions


@@ -9,7 +9,7 @@ class BaseNodeData(ABC, BaseModel):
     desc: Optional[str] = None
 
 class BaseIterationNodeData(BaseNodeData):
-    start_node_id: str
+    start_node_id: Optional[str] = None
 
 class BaseIterationState(BaseModel):
     iteration_node_id: str
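
The practical effect of the change above is that iteration node data can now be constructed before its (fake) start node is wired in. A minimal, self-contained sketch of that shape — the class name is a hypothetical stand-in, not the project's actual module, and it only assumes Pydantic is installed:

from typing import Optional

from pydantic import BaseModel


class IterationNodeDataSketch(BaseModel):
    # Stand-in for BaseIterationNodeData: start_node_id may now be omitted.
    desc: Optional[str] = None
    start_node_id: Optional[str] = None


data = IterationNodeDataSketch(desc="iterate over items")
assert data.start_node_id is None  # previously a required field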


@@ -1,9 +1,9 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMUsage
from models import WorkflowNodeExecutionStatus
@@ -28,6 +28,7 @@ class NodeType(Enum):
     VARIABLE_ASSIGNER = 'variable-assigner'
     LOOP = 'loop'
     ITERATION = 'iteration'
+    ITERATION_START = 'iteration-start'  # fake start node for iteration
     PARAMETER_EXTRACTOR = 'parameter-extractor'
     CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
@@ -56,6 +57,10 @@ class NodeRunMetadataKey(Enum):
     TOOL_INFO = 'tool_info'
     ITERATION_ID = 'iteration_id'
     ITERATION_INDEX = 'iteration_index'
+    PARALLEL_ID = 'parallel_id'
+    PARALLEL_START_NODE_ID = 'parallel_start_node_id'
+    PARENT_PARALLEL_ID = 'parent_parallel_id'
+    PARENT_PARALLEL_START_NODE_ID = 'parent_parallel_start_node_id'
 
 
 class NodeRunResult(BaseModel):
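
For illustration, the four new parallel-related keys can tag a node run that executes inside a nested parallel branch. A hypothetical sketch — only the enum values are taken from the diff above; the metadata dict and the branch/node identifiers are illustrative, not from this commit:

from enum import Enum
from typing import Any


class NodeRunMetadataKey(Enum):
    # Subset of the enum shown above.
    PARALLEL_ID = 'parallel_id'
    PARALLEL_START_NODE_ID = 'parallel_start_node_id'
    PARENT_PARALLEL_ID = 'parent_parallel_id'
    PARENT_PARALLEL_START_NODE_ID = 'parent_parallel_start_node_id'


# Hypothetical metadata for a node running in a parallel branch that is itself
# nested inside an outer parallel branch.
metadata: dict[NodeRunMetadataKey, Any] = {
    NodeRunMetadataKey.PARALLEL_ID: 'parallel-inner',
    NodeRunMetadataKey.PARALLEL_START_NODE_ID: 'node-b',
    NodeRunMetadataKey.PARENT_PARALLEL_ID: 'parallel-outer',
    NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID: 'node-a',
}
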
@@ -65,11 +70,32 @@ class NodeRunResult(BaseModel):
     status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
-    inputs: Optional[Mapping[str, Any]] = None  # node inputs
-    process_data: Optional[dict] = None  # process data
-    outputs: Optional[Mapping[str, Any]] = None  # node outputs
+    inputs: Optional[dict[str, Any]] = None  # node inputs
+    process_data: Optional[dict[str, Any]] = None  # process data
+    outputs: Optional[dict[str, Any]] = None  # node outputs
     metadata: Optional[dict[NodeRunMetadataKey, Any]] = None  # node metadata
     llm_usage: Optional[LLMUsage] = None  # llm usage
     edge_source_handle: Optional[str] = None  # source handle id of node with multiple branches
     error: Optional[str] = None  # error message if status is failed
+
+
+class UserFrom(Enum):
+    """
+    User from
+    """
+    ACCOUNT = "account"
+    END_USER = "end-user"
+
+    @classmethod
+    def value_of(cls, value: str) -> "UserFrom":
+        """
+        Value of
+        :param value: value
+        :return:
+        """
+        for item in cls:
+            if item.value == value:
+                return item
+        raise ValueError(f"Invalid value: {value}")
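
The UserFrom enum added above resolves a raw string back to a member via value_of. A quick usage sketch — the class body is reproduced from the diff; the calls below it are illustrative only:

from enum import Enum


class UserFrom(Enum):
    ACCOUNT = "account"
    END_USER = "end-user"

    @classmethod
    def value_of(cls, value: str) -> "UserFrom":
        for item in cls:
            if item.value == value:
                return item
        raise ValueError(f"Invalid value: {value}")


assert UserFrom.value_of("end-user") is UserFrom.END_USER
# UserFrom.value_of("guest") raises ValueError("Invalid value: guest")

A plain Enum call such as UserFrom("end-user") would resolve the member as well; value_of mainly standardizes the error message.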


@@ -2,6 +2,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Union
from pydantic import BaseModel, Field, model_validator
from typing_extensions import deprecated
from core.app.segments import Segment, Variable, factory
@@ -16,43 +17,52 @@ ENVIRONMENT_VARIABLE_NODE_ID = "env"
 CONVERSATION_VARIABLE_NODE_ID = "conversation"
-class VariablePool:
-    def __init__(
-        self,
-        system_variables: Mapping[SystemVariableKey, Any],
-        user_inputs: Mapping[str, Any],
-        environment_variables: Sequence[Variable],
-        conversation_variables: Sequence[Variable] | None = None,
-    ) -> None:
-        # system variables
-        # for example:
-        #     {
-        #         'query': 'abc',
-        #         'files': []
-        #     }
+class VariablePool(BaseModel):
+    # Variable dictionary is a dictionary for looking up variables by their selector.
+    # The first element of the selector is the node id, it's the first-level key in the dictionary.
+    # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
+    # elements of the selector except the first one.
+    variable_dictionary: dict[str, dict[int, Segment]] = Field(
+        description='Variables mapping',
+        default=defaultdict(dict)
+    )
-        # Variable dictionary is a dictionary for looking up variables by their selector.
-        # The first element of the selector is the node id, it's the first-level key in the dictionary.
-        # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
-        # elements of the selector except the first one.
-        self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict)
+    # TODO: This user inputs is not used for pool.
+    user_inputs: Mapping[str, Any] = Field(
+        description='User inputs',
+    )
-        # TODO: This user inputs is not used for pool.
-        self.user_inputs = user_inputs
+    system_variables: Mapping[SystemVariableKey, Any] = Field(
+        description='System variables',
+    )
+    environment_variables: Sequence[Variable] = Field(
+        description="Environment variables.",
+        default_factory=list
+    )
+    conversation_variables: Sequence[Variable] | None = None
+
+    @model_validator(mode="after")
+    def val_model_after(self):
+        """
+        Append system variables
+        :return:
+        """
         # Add system variables to the variable pool
-        self.system_variables = system_variables
-        for key, value in system_variables.items():
+        for key, value in self.system_variables.items():
             self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
         # Add environment variables to the variable pool
-        for var in environment_variables:
+        for var in self.environment_variables or []:
             self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
         # Add conversation variables to the variable pool
-        for var in conversation_variables or []:
+        for var in self.conversation_variables or []:
             self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
+
+        return self
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""
Adds a variable to the variable pool.
@@ -79,7 +89,7 @@ class VariablePool:
         v = factory.build_segment(value)
         hash_key = hash(tuple(selector[1:]))
-        self._variable_dictionary[selector[0]][hash_key] = v
+        self.variable_dictionary[selector[0]][hash_key] = v
 
     def get(self, selector: Sequence[str], /) -> Segment | None:
         """
@@ -97,7 +107,7 @@ class VariablePool:
         if len(selector) < 2:
             raise ValueError("Invalid selector")
         hash_key = hash(tuple(selector[1:]))
-        value = self._variable_dictionary[selector[0]].get(hash_key)
+        value = self.variable_dictionary[selector[0]].get(hash_key)
         return value
@@ -118,7 +128,7 @@ class VariablePool:
         if len(selector) < 2:
             raise ValueError("Invalid selector")
         hash_key = hash(tuple(selector[1:]))
-        value = self._variable_dictionary[selector[0]].get(hash_key)
+        value = self.variable_dictionary[selector[0]].get(hash_key)
         return value.to_object() if value else None
 
     def remove(self, selector: Sequence[str], /):
@@ -134,7 +144,19 @@ class VariablePool:
         if not selector:
             return
         if len(selector) == 1:
-            self._variable_dictionary[selector[0]] = {}
+            self.variable_dictionary[selector[0]] = {}
             return
         hash_key = hash(tuple(selector[1:]))
-        self._variable_dictionary[selector[0]].pop(hash_key, None)
+        self.variable_dictionary[selector[0]].pop(hash_key, None)
+
+    def remove_node(self, node_id: str, /):
+        """
+        Remove all variables associated with a given node id.
+
+        Args:
+            node_id (str): The node id to remove.
+
+        Returns:
+            None
+        """
+        self.variable_dictionary.pop(node_id, None)
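
The lookup scheme the comments above describe — the first selector element is the node id, the remaining elements are hashed into a second-level key — can be demonstrated with a small, self-contained sketch. Plain values stand in for Segment, and the free functions only mirror (do not reuse) the methods in the diff:

from collections import defaultdict
from collections.abc import Sequence
from typing import Any, Optional

# First-level key: node id. Second-level key: hash of the rest of the selector.
variable_dictionary: dict[str, dict[int, Any]] = defaultdict(dict)


def add(selector: Sequence[str], value: Any) -> None:
    hash_key = hash(tuple(selector[1:]))
    variable_dictionary[selector[0]][hash_key] = value


def get(selector: Sequence[str]) -> Optional[Any]:
    if len(selector) < 2:
        raise ValueError("Invalid selector")
    hash_key = hash(tuple(selector[1:]))
    return variable_dictionary[selector[0]].get(hash_key)


def remove_node(node_id: str) -> None:
    # Drops every variable produced by one node, like the new remove_node above.
    variable_dictionary.pop(node_id, None)


add(("sys", "query"), "abc")
add(("node-1", "output", "text"), "hello")
assert get(("sys", "query")) == "abc"
assert get(("node-1", "output", "text")) == "hello"

remove_node("node-1")
assert get(("node-1", "output", "text")) is None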


@@ -66,8 +66,7 @@ class WorkflowRunState:
         self.variable_pool = variable_pool
 
         self.total_tokens = 0
 
         self.workflow_nodes_and_results = []
-        self.current_iteration_state = None
-        self.workflow_node_steps = 1
-        self.workflow_node_runs = []
+        self.workflow_node_runs = []
+        self.current_iteration_state = None