Fix: rerank switch and validation before run (#9416)
This commit is contained in:
@@ -23,7 +23,7 @@ import type { DataSet } from '@/models/datasets'
|
||||
import { fetchDatasets } from '@/service/datasets'
|
||||
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
|
||||
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
|
||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
|
||||
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
@@ -34,6 +34,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
const startNodeId = startNode?.id
|
||||
const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
|
||||
|
||||
const inputRef = useRef(inputs)
|
||||
|
||||
const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
|
||||
const newInputs = produce(s, (draft) => {
|
||||
if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
|
||||
@@ -43,13 +45,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
})
|
||||
// not work in pass to draft...
|
||||
doSetInputs(newInputs)
|
||||
inputRef.current = newInputs
|
||||
}, [doSetInputs])
|
||||
|
||||
const inputRef = useRef(inputs)
|
||||
useEffect(() => {
|
||||
inputRef.current = inputs
|
||||
}, [inputs])
|
||||
|
||||
const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.query_variable_selector = newVar as ValueSelector
|
||||
@@ -63,9 +61,22 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
|
||||
|
||||
const {
|
||||
modelList: rerankModelList,
|
||||
defaultModel: rerankDefaultModel,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
|
||||
const {
|
||||
currentModel: currentRerankModel,
|
||||
} = useCurrentProviderAndModel(
|
||||
rerankModelList,
|
||||
rerankDefaultModel
|
||||
? {
|
||||
...rerankDefaultModel,
|
||||
provider: rerankDefaultModel.provider.provider,
|
||||
}
|
||||
: undefined,
|
||||
)
|
||||
|
||||
const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
|
||||
const newInputs = produce(inputRef.current, (draft) => {
|
||||
if (!draft.single_retrieval_config) {
|
||||
@@ -110,7 +121,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
// set defaults models
|
||||
useEffect(() => {
|
||||
const inputs = inputRef.current
|
||||
if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider)
|
||||
if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel)
|
||||
return
|
||||
|
||||
if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
|
||||
@@ -130,7 +141,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||
draft.multiple_retrieval_config = {
|
||||
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
|
||||
@@ -138,6 +148,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
reranking_model: multipleRetrievalConfig?.reranking_model,
|
||||
reranking_mode: multipleRetrievalConfig?.reranking_mode,
|
||||
weights: multipleRetrievalConfig?.weights,
|
||||
reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined
|
||||
? multipleRetrievalConfig.reranking_enable
|
||||
: Boolean(currentRerankModel && rerankDefaultModel),
|
||||
}
|
||||
})
|
||||
setInputs(newInput)
|
||||
@@ -194,14 +207,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
const inputs = inputRef.current
|
||||
let query_variable_selector: ValueSelector = inputs.query_variable_selector
|
||||
if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
|
||||
query_variable_selector = [startNodeId, 'sys.query']
|
||||
|
||||
setInputs({
|
||||
...inputs,
|
||||
query_variable_selector,
|
||||
})
|
||||
setInputs(produce(inputs, (draft) => {
|
||||
draft.query_variable_selector = query_variable_selector
|
||||
}))
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [])
|
||||
|
||||
|
Reference in New Issue
Block a user