feat: n to 1 retrieval legacy (#6554)
This commit is contained in:
@@ -1,4 +1,9 @@
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
} from 'react'
|
||||
import produce from 'immer'
|
||||
import { isEqual } from 'lodash-es'
|
||||
import type { ValueSelector, Var } from '../../types'
|
||||
@@ -8,6 +13,10 @@ import {
|
||||
useWorkflow,
|
||||
} from '../../hooks'
|
||||
import type { KnowledgeRetrievalNodeType, MultipleRetrievalConfig } from './types'
|
||||
import {
|
||||
getMultipleRetrievalConfig,
|
||||
getSelectedDatasetsMode,
|
||||
} from './utils'
|
||||
import { RETRIEVE_TYPE } from '@/types/app'
|
||||
import { DATASET_DEFAULT } from '@/config'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
@@ -126,34 +135,20 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
draft.multiple_retrieval_config = {
|
||||
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
|
||||
score_threshold: multipleRetrievalConfig?.score_threshold,
|
||||
reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay
|
||||
? undefined
|
||||
: (!multipleRetrievalConfig?.reranking_model?.provider
|
||||
? {
|
||||
provider: rerankDefaultModel?.provider?.provider || '',
|
||||
model: rerankDefaultModel?.model || '',
|
||||
}
|
||||
: multipleRetrievalConfig?.reranking_model),
|
||||
reranking_model: multipleRetrievalConfig?.reranking_model,
|
||||
}
|
||||
})
|
||||
setInputs(newInput)
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [currentProvider?.provider, currentModel, rerankDefaultModel])
|
||||
|
||||
const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
|
||||
const [rerankModelOpen, setRerankModelOpen] = useState(false)
|
||||
const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.retrieval_mode = newMode
|
||||
if (newMode === RETRIEVE_TYPE.multiWay) {
|
||||
draft.multiple_retrieval_config = {
|
||||
top_k: draft.multiple_retrieval_config?.top_k || DATASET_DEFAULT.top_k,
|
||||
score_threshold: draft.multiple_retrieval_config?.score_threshold,
|
||||
reranking_model: !draft.multiple_retrieval_config?.reranking_model?.provider
|
||||
? {
|
||||
provider: rerankDefaultModel?.provider?.provider || '',
|
||||
model: rerankDefaultModel?.model || '',
|
||||
}
|
||||
: draft.multiple_retrieval_config?.reranking_model,
|
||||
}
|
||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets)
|
||||
}
|
||||
else {
|
||||
const hasSetModel = draft.single_retrieval_config?.model?.provider
|
||||
@@ -170,17 +165,16 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
}
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, rerankDefaultModel?.model, rerankDefaultModel?.provider?.provider, setInputs])
|
||||
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets])
|
||||
|
||||
const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.multiple_retrieval_config = newConfig
|
||||
draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets)
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [inputs, setInputs])
|
||||
}, [inputs, setInputs, selectedDatasets])
|
||||
|
||||
// datasets
|
||||
const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
|
||||
useEffect(() => {
|
||||
(async () => {
|
||||
const inputs = inputRef.current
|
||||
@@ -210,12 +204,25 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
}, [])
|
||||
|
||||
const handleOnDatasetsChange = useCallback((newDatasets: DataSet[]) => {
|
||||
const {
|
||||
allEconomic,
|
||||
mixtureHighQualityAndEconomic,
|
||||
inconsistentEmbeddingModel,
|
||||
} = getSelectedDatasetsMode(newDatasets)
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.dataset_ids = newDatasets.map(d => d.id)
|
||||
|
||||
if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
|
||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets)
|
||||
}
|
||||
})
|
||||
setInputs(newInputs)
|
||||
setSelectedDatasets(newDatasets)
|
||||
}, [inputs, setInputs])
|
||||
|
||||
if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel)
|
||||
setRerankModelOpen(true)
|
||||
}, [inputs, setInputs, payload.retrieval_mode])
|
||||
|
||||
const filterVar = useCallback((varPayload: Var) => {
|
||||
return varPayload.type === VarType.string
|
||||
@@ -266,6 +273,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
query,
|
||||
setQuery,
|
||||
runResult,
|
||||
rerankModelOpen,
|
||||
setRerankModelOpen,
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user