feat: Parallel Execution of Nodes in Workflows (#8192)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: Yi <yxiaoisme@gmail.com> Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -9,7 +9,7 @@ class BaseNodeData(ABC, BaseModel):
|
||||
desc: Optional[str] = None
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
start_node_id: str
|
||||
start_node_id: Optional[str] = None
|
||||
|
||||
class BaseIterationState(BaseModel):
|
||||
iteration_node_id: str
|
||||
|
@@ -1,9 +1,9 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from models import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ class NodeType(Enum):
|
||||
VARIABLE_ASSIGNER = 'variable-assigner'
|
||||
LOOP = 'loop'
|
||||
ITERATION = 'iteration'
|
||||
ITERATION_START = 'iteration-start' # fake start node for iteration
|
||||
PARAMETER_EXTRACTOR = 'parameter-extractor'
|
||||
CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
|
||||
|
||||
@@ -56,6 +57,10 @@ class NodeRunMetadataKey(Enum):
|
||||
TOOL_INFO = 'tool_info'
|
||||
ITERATION_ID = 'iteration_id'
|
||||
ITERATION_INDEX = 'iteration_index'
|
||||
PARALLEL_ID = 'parallel_id'
|
||||
PARALLEL_START_NODE_ID = 'parallel_start_node_id'
|
||||
PARENT_PARALLEL_ID = 'parent_parallel_id'
|
||||
PARENT_PARALLEL_START_NODE_ID = 'parent_parallel_start_node_id'
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
@@ -65,11 +70,32 @@ class NodeRunResult(BaseModel):
|
||||
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
||||
|
||||
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
||||
process_data: Optional[dict] = None # process data
|
||||
outputs: Optional[Mapping[str, Any]] = None # node outputs
|
||||
inputs: Optional[dict[str, Any]] = None # node inputs
|
||||
process_data: Optional[dict[str, Any]] = None # process data
|
||||
outputs: Optional[dict[str, Any]] = None # node outputs
|
||||
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
|
||||
llm_usage: Optional[LLMUsage] = None # llm usage
|
||||
|
||||
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
||||
|
||||
error: Optional[str] = None # error message if status is failed
|
||||
|
||||
|
||||
class UserFrom(Enum):
|
||||
"""
|
||||
User from
|
||||
"""
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end-user"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "UserFrom":
|
||||
"""
|
||||
Value of
|
||||
:param value: value
|
||||
:return:
|
||||
"""
|
||||
for item in cls:
|
||||
if item.value == value:
|
||||
return item
|
||||
raise ValueError(f"Invalid value: {value}")
|
||||
|
@@ -2,6 +2,7 @@ from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.app.segments import Segment, Variable, factory
|
||||
@@ -16,43 +17,52 @@ ENVIRONMENT_VARIABLE_NODE_ID = "env"
|
||||
CONVERSATION_VARIABLE_NODE_ID = "conversation"
|
||||
|
||||
|
||||
class VariablePool:
|
||||
def __init__(
|
||||
self,
|
||||
system_variables: Mapping[SystemVariableKey, Any],
|
||||
user_inputs: Mapping[str, Any],
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable] | None = None,
|
||||
) -> None:
|
||||
# system variables
|
||||
# for example:
|
||||
# {
|
||||
# 'query': 'abc',
|
||||
# 'files': []
|
||||
# }
|
||||
class VariablePool(BaseModel):
|
||||
# Variable dictionary is a dictionary for looking up variables by their selector.
|
||||
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||
# elements of the selector except the first one.
|
||||
variable_dictionary: dict[str, dict[int, Segment]] = Field(
|
||||
description='Variables mapping',
|
||||
default=defaultdict(dict)
|
||||
)
|
||||
|
||||
# Variable dictionary is a dictionary for looking up variables by their selector.
|
||||
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||
# elements of the selector except the first one.
|
||||
self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict)
|
||||
# TODO: This user inputs is not used for pool.
|
||||
user_inputs: Mapping[str, Any] = Field(
|
||||
description='User inputs',
|
||||
)
|
||||
|
||||
# TODO: This user inputs is not used for pool.
|
||||
self.user_inputs = user_inputs
|
||||
system_variables: Mapping[SystemVariableKey, Any] = Field(
|
||||
description='System variables',
|
||||
)
|
||||
|
||||
environment_variables: Sequence[Variable] = Field(
|
||||
description="Environment variables.",
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
conversation_variables: Sequence[Variable] | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def val_model_after(self):
|
||||
"""
|
||||
Append system variables
|
||||
:return:
|
||||
"""
|
||||
# Add system variables to the variable pool
|
||||
self.system_variables = system_variables
|
||||
for key, value in system_variables.items():
|
||||
for key, value in self.system_variables.items():
|
||||
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
|
||||
|
||||
# Add environment variables to the variable pool
|
||||
for var in environment_variables:
|
||||
for var in self.environment_variables or []:
|
||||
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
|
||||
|
||||
# Add conversation variables to the variable pool
|
||||
for var in conversation_variables or []:
|
||||
for var in self.conversation_variables or []:
|
||||
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
|
||||
|
||||
return self
|
||||
|
||||
def add(self, selector: Sequence[str], value: Any, /) -> None:
|
||||
"""
|
||||
Adds a variable to the variable pool.
|
||||
@@ -79,7 +89,7 @@ class VariablePool:
|
||||
v = factory.build_segment(value)
|
||||
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
self._variable_dictionary[selector[0]][hash_key] = v
|
||||
self.variable_dictionary[selector[0]][hash_key] = v
|
||||
|
||||
def get(self, selector: Sequence[str], /) -> Segment | None:
|
||||
"""
|
||||
@@ -97,7 +107,7 @@ class VariablePool:
|
||||
if len(selector) < 2:
|
||||
raise ValueError("Invalid selector")
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
||||
value = self.variable_dictionary[selector[0]].get(hash_key)
|
||||
|
||||
return value
|
||||
|
||||
@@ -118,7 +128,7 @@ class VariablePool:
|
||||
if len(selector) < 2:
|
||||
raise ValueError("Invalid selector")
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
||||
value = self.variable_dictionary[selector[0]].get(hash_key)
|
||||
return value.to_object() if value else None
|
||||
|
||||
def remove(self, selector: Sequence[str], /):
|
||||
@@ -134,7 +144,19 @@ class VariablePool:
|
||||
if not selector:
|
||||
return
|
||||
if len(selector) == 1:
|
||||
self._variable_dictionary[selector[0]] = {}
|
||||
self.variable_dictionary[selector[0]] = {}
|
||||
return
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
self._variable_dictionary[selector[0]].pop(hash_key, None)
|
||||
self.variable_dictionary[selector[0]].pop(hash_key, None)
|
||||
|
||||
def remove_node(self, node_id: str, /):
|
||||
"""
|
||||
Remove all variables associated with a given node id.
|
||||
|
||||
Args:
|
||||
node_id (str): The node id to remove.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.variable_dictionary.pop(node_id, None)
|
||||
|
@@ -66,8 +66,7 @@ class WorkflowRunState:
|
||||
self.variable_pool = variable_pool
|
||||
|
||||
self.total_tokens = 0
|
||||
self.workflow_nodes_and_results = []
|
||||
|
||||
self.current_iteration_state = None
|
||||
self.workflow_node_steps = 1
|
||||
self.workflow_node_runs = []
|
||||
self.workflow_node_runs = []
|
||||
self.current_iteration_state = None
|
||||
|
Reference in New Issue
Block a user