Fix: rerank switch and validation before run (#9416)

This commit is contained in:
Yi Xiao
2024-10-17 14:26:38 +08:00
committed by GitHub
parent 4ac99ffe0e
commit 8a1f106c72
5 changed files with 61 additions and 84 deletions

View File

@@ -23,7 +23,7 @@ import type { DataSet } from '@/models/datasets'
import { fetchDatasets } from '@/service/datasets'
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
@@ -34,6 +34,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
const startNodeId = startNode?.id
const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
const inputRef = useRef(inputs)
const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
const newInputs = produce(s, (draft) => {
if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
@@ -43,13 +45,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
})
// not work in pass to draft...
doSetInputs(newInputs)
inputRef.current = newInputs
}, [doSetInputs])
const inputRef = useRef(inputs)
useEffect(() => {
inputRef.current = inputs
}, [inputs])
const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
const newInputs = produce(inputs, (draft) => {
draft.query_variable_selector = newVar as ValueSelector
@@ -63,9 +61,22 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
const {
modelList: rerankModelList,
defaultModel: rerankDefaultModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const {
currentModel: currentRerankModel,
} = useCurrentProviderAndModel(
rerankModelList,
rerankDefaultModel
? {
...rerankDefaultModel,
provider: rerankDefaultModel.provider.provider,
}
: undefined,
)
const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
const newInputs = produce(inputRef.current, (draft) => {
if (!draft.single_retrieval_config) {
@@ -110,7 +121,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
// set defaults models
useEffect(() => {
const inputs = inputRef.current
if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider)
if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel)
return
if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
@@ -130,7 +141,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
}
}
}
const multipleRetrievalConfig = draft.multiple_retrieval_config
draft.multiple_retrieval_config = {
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
@@ -138,6 +148,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
reranking_model: multipleRetrievalConfig?.reranking_model,
reranking_mode: multipleRetrievalConfig?.reranking_mode,
weights: multipleRetrievalConfig?.weights,
reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined
? multipleRetrievalConfig.reranking_enable
: Boolean(currentRerankModel && rerankDefaultModel),
}
})
setInputs(newInput)
@@ -194,14 +207,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
}, [])
useEffect(() => {
const inputs = inputRef.current
let query_variable_selector: ValueSelector = inputs.query_variable_selector
if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
query_variable_selector = [startNodeId, 'sys.query']
setInputs({
...inputs,
query_variable_selector,
})
setInputs(produce(inputs, (draft) => {
draft.query_variable_selector = query_variable_selector
}))
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])