Feat: rerank model verification in front end (#9271)

This commit is contained in:
Yi Xiao
2024-10-12 21:24:43 +08:00
committed by GitHub
parent c6b74daa0a
commit 793205afc5
6 changed files with 159 additions and 24 deletions

View File

@@ -1,17 +1,25 @@
import { useCallback } from 'react'
import { useStoreApi } from 'reactflow'
import { useTranslation } from 'react-i18next'
import { useWorkflowStore } from '../store'
import {
BlockEnum,
WorkflowRunningStatus,
} from '../types'
import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
import type { Node } from '../types'
import { useWorkflow } from './use-workflow'
import {
useIsChatMode,
useNodesSyncDraft,
useWorkflowInteractions,
useWorkflowRun,
} from './index'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { useFeaturesStore } from '@/app/components/base/features/hooks'
import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default'
import Toast from '@/app/components/base/toast'
export const useWorkflowStartRun = () => {
const store = useStoreApi()
@@ -20,7 +28,26 @@ export const useWorkflowStartRun = () => {
const isChatMode = useIsChatMode()
const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions()
const { handleRun } = useWorkflowRun()
const { isFromStartNode } = useWorkflow()
const { doSyncWorkflowDraft } = useNodesSyncDraft()
const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault
const { t } = useTranslation()
const {
modelList: rerankModelList,
defaultModel: rerankDefaultModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const {
currentModel,
} = useCurrentProviderAndModel(
rerankModelList,
rerankDefaultModel
? {
...rerankDefaultModel,
provider: rerankDefaultModel.provider.provider,
}
: undefined,
)
const handleWorkflowStartRunInWorkflow = useCallback(async () => {
const {
@@ -33,6 +60,9 @@ export const useWorkflowStartRun = () => {
const { getNodes } = store.getState()
const nodes = getNodes()
const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
const knowledgeRetrievalNodes = nodes.filter((node: Node<KnowledgeRetrievalNodeType>) =>
node.data.type === BlockEnum.KnowledgeRetrieval,
)
const startVariables = startNode?.data.variables || []
const fileSettings = featuresStore!.getState().features.file
const {
@@ -42,6 +72,31 @@ export const useWorkflowStartRun = () => {
setShowEnvPanel,
} = workflowStore.getState()
if (knowledgeRetrievalNodes.length > 0) {
for (const node of knowledgeRetrievalNodes) {
if (isFromStartNode(node.id)) {
const res = checkKnowledgeRetrievalValid(node.data, t)
if (!res.isValid || !currentModel || !rerankDefaultModel) {
const errorMessage = res.errorMessage
if (errorMessage) {
Toast.notify({
type: 'error',
message: errorMessage,
})
return false
}
else {
Toast.notify({
type: 'error',
message: t('appDebug.datasetConfig.rerankModelRequired'),
})
return false
}
}
}
}
}
setShowEnvPanel(false)
if (showDebugAndPreviewPanel) {

View File

@@ -235,6 +235,33 @@ export const useWorkflow = () => {
return nodes.filter(node => node.parentId === nodeId)
}, [store])
const isFromStartNode = useCallback((nodeId: string) => {
const { getNodes } = store.getState()
const nodes = getNodes()
const currentNode = nodes.find(node => node.id === nodeId)
if (!currentNode)
return false
if (currentNode.data.type === BlockEnum.Start)
return true
const checkPreviousNodes = (node: Node) => {
const previousNodes = getBeforeNodeById(node.id)
for (const prevNode of previousNodes) {
if (prevNode.data.type === BlockEnum.Start)
return true
if (checkPreviousNodes(prevNode))
return true
}
return false
}
return checkPreviousNodes(currentNode)
}, [store, getBeforeNodeById])
const handleOutVarRenameChange = useCallback((nodeId: string, oldValeSelector: ValueSelector, newVarSelector: ValueSelector) => {
const { getNodes, setNodes } = store.getState()
const afterNodes = getAfterNodesInSameBranch(nodeId)
@@ -389,6 +416,7 @@ export const useWorkflow = () => {
checkParallelLimit,
checkNestedParallelLimit,
isValidConnection,
isFromStartNode,
formatTimeFromNow,
getNode,
getBeforeNodeById,