feat: Parallel Execution of Nodes in Workflows (#8192)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
takatost
2024-09-10 15:23:16 +08:00
committed by GitHub
parent 5da0182800
commit dabfd74622
156 changed files with 11158 additions and 5605 deletions

View File

@@ -6,6 +6,7 @@ import {
} from 'react'
import dayjs from 'dayjs'
import { uniqBy } from 'lodash-es'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import {
getIncomers,
@@ -29,6 +30,11 @@ import {
useWorkflowStore,
} from '../store'
import {
getParallelInfo,
} from '../utils'
import {
PARALLEL_DEPTH_LIMIT,
PARALLEL_LIMIT,
SUPPORT_OUTPUT_VARS_NODE,
} from '../constants'
import { CUSTOM_NOTE_NODE } from '../note-node/constants'
@@ -50,6 +56,7 @@ import {
} from '@/service/tools'
import I18n from '@/context/i18n'
import { CollectionType } from '@/app/components/tools/types'
import { CUSTOM_ITERATION_START_NODE } from '@/app/components/workflow/nodes/iteration-start/constants'
export const useIsChatMode = () => {
const appDetail = useAppStore(s => s.appDetail)
@@ -58,6 +65,7 @@ export const useIsChatMode = () => {
}
export const useWorkflow = () => {
const { t } = useTranslation()
const { locale } = useContext(I18n)
const store = useStoreApi()
const workflowStore = useWorkflowStore()
@@ -77,7 +85,7 @@ export const useWorkflow = () => {
const currentNode = nodes.find(node => node.id === nodeId)
if (currentNode?.parentId)
startNode = nodes.find(node => node.parentId === currentNode.parentId && node.data.isIterationStart)
startNode = nodes.find(node => node.parentId === currentNode.parentId && node.type === CUSTOM_ITERATION_START_NODE)
if (!startNode)
return []
@@ -275,6 +283,45 @@ export const useWorkflow = () => {
return isUsed
}, [isVarUsedInNodes])
const checkParallelLimit = useCallback((nodeId: string) => {
const {
getNodes,
edges,
} = store.getState()
const nodes = getNodes()
const currentNode = nodes.find(node => node.id === nodeId)!
const sourceNodeOutgoers = getOutgoers(currentNode, nodes, edges)
if (sourceNodeOutgoers.length > PARALLEL_LIMIT - 1) {
const { setShowTips } = workflowStore.getState()
setShowTips(t('workflow.common.parallelTip.limit', { num: PARALLEL_LIMIT }))
return false
}
return true
}, [store, workflowStore, t])
const checkNestedParallelLimit = useCallback((nodes: Node[], edges: Edge[], parentNodeId?: string) => {
const {
parallelList,
hasAbnormalEdges,
} = getParallelInfo(nodes, edges, parentNodeId)
if (hasAbnormalEdges)
return false
for (let i = 0; i < parallelList.length; i++) {
const parallel = parallelList[i]
if (parallel.depth > PARALLEL_DEPTH_LIMIT) {
const { setShowTips } = workflowStore.getState()
setShowTips(t('workflow.common.parallelTip.depthLimit', { num: PARALLEL_DEPTH_LIMIT }))
return false
}
}
return true
}, [t, workflowStore])
const isValidConnection = useCallback(({ source, target }: Connection) => {
const {
edges,
@@ -284,12 +331,15 @@ export const useWorkflow = () => {
const sourceNode: Node = nodes.find(node => node.id === source)!
const targetNode: Node = nodes.find(node => node.id === target)!
if (targetNode.data.isIterationStart)
if (!checkParallelLimit(source!))
return false
if (sourceNode.type === CUSTOM_NOTE_NODE || targetNode.type === CUSTOM_NOTE_NODE)
return false
if (sourceNode.parentId !== targetNode.parentId)
return false
if (sourceNode && targetNode) {
const sourceNodeAvailableNextNodes = nodesExtraData[sourceNode.data.type].availableNextNodes
const targetNodeAvailablePrevNodes = [...nodesExtraData[targetNode.data.type].availablePrevNodes, BlockEnum.Start]
@@ -316,7 +366,7 @@ export const useWorkflow = () => {
}
return !hasCycle(targetNode)
}, [store, nodesExtraData])
}, [store, nodesExtraData, checkParallelLimit])
const formatTimeFromNow = useCallback((time: number) => {
return dayjs(time).locale(locale === 'zh-Hans' ? 'zh-cn' : locale).fromNow()
@@ -339,6 +389,8 @@ export const useWorkflow = () => {
isVarUsedInNodes,
removeUsedVarInNodes,
isNodeVarsUsedInNodes,
checkParallelLimit,
checkNestedParallelLimit,
isValidConnection,
formatTimeFromNow,
getNode,