Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: Garfield Dai <dai.hai@foxmail.com> Co-authored-by: chenhe <guchenhe@gmail.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
import type { BackendModel } from '../../header/account-setting/model-page/declarations'
|
||||
import { RETRIEVE_METHOD, type RetrievalConfig } from '@/types/app'
|
||||
import type {
|
||||
DefaultModelResponse,
|
||||
Model,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
|
||||
export const isReRankModelSelected = ({
|
||||
rerankDefaultModel,
|
||||
@@ -8,15 +11,18 @@ export const isReRankModelSelected = ({
|
||||
rerankModelList,
|
||||
indexMethod,
|
||||
}: {
|
||||
rerankDefaultModel?: BackendModel
|
||||
rerankDefaultModel?: DefaultModelResponse
|
||||
isRerankDefaultModelVaild: boolean
|
||||
retrievalConfig: RetrievalConfig
|
||||
rerankModelList: BackendModel[]
|
||||
rerankModelList: Model[]
|
||||
indexMethod?: string
|
||||
}) => {
|
||||
const rerankModelSelected = (() => {
|
||||
if (retrievalConfig.reranking_model?.reranking_model_name)
|
||||
return !!rerankModelList.find(({ model_name }) => model_name === retrievalConfig.reranking_model?.reranking_model_name)
|
||||
if (retrievalConfig.reranking_model?.reranking_model_name) {
|
||||
const provider = rerankModelList.find(({ provider }) => provider === retrievalConfig.reranking_model?.reranking_provider_name)
|
||||
|
||||
return provider?.models.find(({ model }) => model === retrievalConfig.reranking_model?.reranking_model_name)
|
||||
}
|
||||
|
||||
if (isRerankDefaultModelVaild)
|
||||
return !!rerankDefaultModel
|
||||
@@ -39,7 +45,7 @@ export const ensureRerankModelSelected = ({
|
||||
indexMethod,
|
||||
retrievalConfig,
|
||||
}: {
|
||||
rerankDefaultModel: BackendModel
|
||||
rerankDefaultModel: DefaultModelResponse
|
||||
retrievalConfig: RetrievalConfig
|
||||
indexMethod?: string
|
||||
}) => {
|
||||
@@ -52,8 +58,8 @@ export const ensureRerankModelSelected = ({
|
||||
return {
|
||||
...retrievalConfig,
|
||||
reranking_model: {
|
||||
reranking_provider_name: rerankDefaultModel.model_provider.provider_name,
|
||||
reranking_model_name: rerankDefaultModel.model_name,
|
||||
reranking_provider_name: rerankDefaultModel.provider.provider,
|
||||
reranking_model_name: rerankDefaultModel.model,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@@ -9,6 +9,7 @@ import RadioCard from '@/app/components/base/radio-card'
|
||||
import { PatternRecognition, Semantic } from '@/app/components/base/icons/src/vender/solid/development'
|
||||
import { FileSearch02 } from '@/app/components/base/icons/src/vender/solid/files'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
|
||||
type Props = {
|
||||
value: RetrievalConfig
|
||||
@@ -20,14 +21,15 @@ const RetrievalMethodConfig: FC<Props> = ({
|
||||
onChange,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { supportRetrievalMethods, rerankDefaultModel } = useProviderContext()
|
||||
const { supportRetrievalMethods } = useProviderContext()
|
||||
const { data: rerankDefaultModel } = useDefaultModel(3)
|
||||
const value = (() => {
|
||||
if (!passValue.reranking_model.reranking_model_name) {
|
||||
return {
|
||||
...passValue,
|
||||
reranking_model: {
|
||||
reranking_provider_name: rerankDefaultModel?.model_provider.provider_name || '',
|
||||
reranking_model_name: rerankDefaultModel?.model_name || '',
|
||||
reranking_provider_name: rerankDefaultModel?.provider.provider || '',
|
||||
reranking_model_name: rerankDefaultModel?.model || '',
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@@ -9,10 +9,9 @@ import { RETRIEVE_METHOD } from '@/types/app'
|
||||
import Switch from '@/app/components/base/switch'
|
||||
import Tooltip from '@/app/components/base/tooltip-plus'
|
||||
import { HelpCircle } from '@/app/components/base/icons/src/vender/line/general'
|
||||
import ModelSelector from '@/app/components/header/account-setting/model-page/model-selector'
|
||||
import { ModelType } from '@/app/components/header/account-setting/model-page/declarations'
|
||||
import type { RetrievalConfig } from '@/types/app'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
|
||||
import { useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
|
||||
type Props = {
|
||||
type: RETRIEVE_METHOD
|
||||
@@ -29,8 +28,9 @@ const RetrievalParamConfig: FC<Props> = ({
|
||||
const canToggleRerankModalEnable = type !== RETRIEVE_METHOD.hybrid
|
||||
const isEconomical = type === RETRIEVE_METHOD.invertedIndex
|
||||
const {
|
||||
rerankDefaultModel,
|
||||
} = useProviderContext()
|
||||
defaultModel: rerankDefaultModel,
|
||||
modelList: rerankModelList,
|
||||
} = useModelListAndDefaultModel(3)
|
||||
|
||||
const rerankModel = (() => {
|
||||
if (value.reranking_model) {
|
||||
@@ -41,8 +41,8 @@ const RetrievalParamConfig: FC<Props> = ({
|
||||
}
|
||||
else if (rerankDefaultModel) {
|
||||
return {
|
||||
provider_name: rerankDefaultModel.model_provider.provider_name,
|
||||
model_name: rerankDefaultModel.model_name,
|
||||
provider_name: rerankDefaultModel.provider.provider,
|
||||
model_name: rerankDefaultModel.model,
|
||||
}
|
||||
}
|
||||
})()
|
||||
@@ -71,24 +71,21 @@ const RetrievalParamConfig: FC<Props> = ({
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<ModelSelector
|
||||
whenEmptyGoToSetting
|
||||
popClassName='!max-w-[100%] !w-full'
|
||||
value={rerankModel && { providerName: rerankModel.provider_name, modelName: rerankModel.model_name } as any}
|
||||
modelType={ModelType.reranking}
|
||||
readonly={!value.reranking_enable && type !== RETRIEVE_METHOD.hybrid}
|
||||
onChange={(v) => {
|
||||
onChange({
|
||||
...value,
|
||||
reranking_model: {
|
||||
reranking_provider_name: v.model_provider.provider_name,
|
||||
reranking_model_name: v.model_name,
|
||||
},
|
||||
})
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<ModelSelector
|
||||
triggerClassName={`${!value.reranking_enable && type !== RETRIEVE_METHOD.hybrid && '!opacity-60 !cursor-not-allowed'}`}
|
||||
defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }}
|
||||
modelList={rerankModelList}
|
||||
readonly={!value.reranking_enable && type !== RETRIEVE_METHOD.hybrid}
|
||||
onSelect={(v) => {
|
||||
onChange({
|
||||
...value,
|
||||
reranking_model: {
|
||||
reranking_provider_name: v.provider,
|
||||
reranking_model_name: v.model,
|
||||
},
|
||||
})
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
|
@@ -11,8 +11,8 @@ import type { DataSet, FileItem, createDocumentResponse } from '@/models/dataset
|
||||
import { fetchDataSource } from '@/service/common'
|
||||
import { fetchDatasetDetail } from '@/service/datasets'
|
||||
import type { NotionPage } from '@/models/common'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { useModalContext } from '@/context/modal-context'
|
||||
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
|
||||
type DatasetUpdateFormProps = {
|
||||
datasetId?: string
|
||||
@@ -28,7 +28,7 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => {
|
||||
const [fileList, setFiles] = useState<FileItem[]>([])
|
||||
const [result, setResult] = useState<createDocumentResponse | undefined>()
|
||||
const [hasError, setHasError] = useState(false)
|
||||
const { embeddingsDefaultModel } = useProviderContext()
|
||||
const { data: embeddingsDefaultModel } = useDefaultModel(2)
|
||||
|
||||
const [notionPages, setNotionPages] = useState<NotionPage[]>([])
|
||||
const updateNotionPages = (value: NotionPage[]) => {
|
||||
|
@@ -38,9 +38,9 @@ import { useDatasetDetailContext } from '@/context/dataset-detail'
|
||||
import I18n from '@/context/i18n'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import { RETRIEVE_METHOD } from '@/types/app'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
|
||||
type ValueOf<T> = T[keyof T]
|
||||
type StepTwoProps = {
|
||||
@@ -268,10 +268,10 @@ const StepTwo = ({
|
||||
}
|
||||
}
|
||||
const {
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelVaild,
|
||||
rerankModelList,
|
||||
} = useProviderContext()
|
||||
modelList: rerankModelList,
|
||||
defaultModel: rerankDefaultModel,
|
||||
currentModel: isRerankDefaultModelVaild,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(3)
|
||||
const getCreationParams = () => {
|
||||
let params
|
||||
if (isSetting) {
|
||||
@@ -289,7 +289,7 @@ const StepTwo = ({
|
||||
if (
|
||||
!isReRankModelSelected({
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelVaild,
|
||||
isRerankDefaultModelVaild: !!isRerankDefaultModelVaild,
|
||||
rerankModelList,
|
||||
// eslint-disable-next-line @typescript-eslint/no-use-before-define
|
||||
retrievalConfig,
|
||||
@@ -489,8 +489,8 @@ const StepTwo = ({
|
||||
search_method: RETRIEVE_METHOD.semantic,
|
||||
reranking_enable: false,
|
||||
reranking_model: {
|
||||
reranking_provider_name: rerankDefaultModel?.model_provider.provider_name,
|
||||
reranking_model_name: rerankDefaultModel?.model_name,
|
||||
reranking_provider_name: rerankDefaultModel?.provider.provider,
|
||||
reranking_model_name: rerankDefaultModel?.model,
|
||||
},
|
||||
top_k: 3,
|
||||
score_threshold_enabled: false,
|
||||
|
@@ -13,7 +13,7 @@ import Loading from '@/app/components/base/loading'
|
||||
import StepTwo from '@/app/components/datasets/create/step-two'
|
||||
import AccountSetting from '@/app/components/header/account-setting'
|
||||
import AppUnavailable from '@/app/components/base/app-unavailable'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
|
||||
type DocumentSettingsProps = {
|
||||
datasetId: string
|
||||
@@ -26,7 +26,7 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
|
||||
const [isShowSetAPIKey, { setTrue: showSetAPIKey, setFalse: hideSetAPIkey }] = useBoolean()
|
||||
const [hasError, setHasError] = useState(false)
|
||||
const { indexingTechnique, dataset } = useContext(DatasetDetailContext)
|
||||
const { embeddingsDefaultModel } = useProviderContext()
|
||||
const { data: embeddingsDefaultModel } = useDefaultModel(2)
|
||||
|
||||
const saveHandler = () => router.push(`/datasets/${datasetId}/documents/${documentId}`)
|
||||
|
||||
|
@@ -8,8 +8,8 @@ import type { RetrievalConfig } from '@/types/app'
|
||||
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
|
||||
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
|
||||
import Button from '@/app/components/base/button'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
|
||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
|
||||
type Props = {
|
||||
indexMethod: string
|
||||
@@ -36,16 +36,16 @@ const ModifyRetrievalModal: FC<Props> = ({
|
||||
// }, ref)
|
||||
|
||||
const {
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelVaild,
|
||||
rerankModelList,
|
||||
} = useProviderContext()
|
||||
modelList: rerankModelList,
|
||||
defaultModel: rerankDefaultModel,
|
||||
currentModel: isRerankDefaultModelVaild,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(3)
|
||||
|
||||
const handleSave = () => {
|
||||
if (
|
||||
!isReRankModelSelected({
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelVaild,
|
||||
isRerankDefaultModelVaild: !!isRerankDefaultModelVaild,
|
||||
rerankModelList,
|
||||
retrievalConfig,
|
||||
indexMethod,
|
||||
|
@@ -15,14 +15,16 @@ import { ToastContext } from '@/app/components/base/toast'
|
||||
import Button from '@/app/components/base/button'
|
||||
import { updateDatasetSetting } from '@/service/datasets'
|
||||
import type { DataSet, DataSetListResponse } from '@/models/datasets'
|
||||
import ModelSelector from '@/app/components/header/account-setting/model-page/model-selector'
|
||||
import type { ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
|
||||
import { ModelType } from '@/app/components/header/account-setting/model-page/declarations'
|
||||
import DatasetDetailContext from '@/context/dataset-detail'
|
||||
import { type RetrievalConfig } from '@/types/app'
|
||||
import { useModalContext } from '@/context/modal-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
|
||||
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
|
||||
import {
|
||||
useModelList,
|
||||
useModelListAndDefaultModelAndCurrentProviderAndModel,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
|
||||
const rowClass = `
|
||||
flex justify-between py-4 flex-wrap gap-y-2
|
||||
`
|
||||
@@ -56,11 +58,13 @@ const Form = () => {
|
||||
const [permission, setPermission] = useState(currentDataset?.permission)
|
||||
const [indexMethod, setIndexMethod] = useState(currentDataset?.indexing_technique)
|
||||
const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict as RetrievalConfig)
|
||||
|
||||
const {
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelVaild,
|
||||
rerankModelList,
|
||||
} = useProviderContext()
|
||||
modelList: rerankModelList,
|
||||
defaultModel: rerankDefaultModel,
|
||||
currentModel: isRerankDefaultModelVaild,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(3)
|
||||
const { data: embeddingModelList } = useModelList(2)
|
||||
|
||||
const handleSave = async () => {
|
||||
if (loading)
|
||||
@@ -72,7 +76,7 @@ const Form = () => {
|
||||
if (
|
||||
!isReRankModelSelected({
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelVaild,
|
||||
isRerankDefaultModelVaild: !!isRerankDefaultModelVaild,
|
||||
rerankModelList,
|
||||
retrievalConfig,
|
||||
indexMethod,
|
||||
@@ -183,17 +187,15 @@ const Form = () => {
|
||||
<div>{t('datasetSettings.form.embeddingModel')}</div>
|
||||
</div>
|
||||
<div className='w-[480px]'>
|
||||
<div className='w-full h-9 rounded-lg bg-gray-100 opacity-60'>
|
||||
<ModelSelector
|
||||
readonly
|
||||
value={{
|
||||
providerName: currentDataset.embedding_model_provider as ProviderEnum,
|
||||
modelName: currentDataset.embedding_model,
|
||||
}}
|
||||
modelType={ModelType.embeddings}
|
||||
onChange={() => {}}
|
||||
/>
|
||||
</div>
|
||||
<ModelSelector
|
||||
readonly
|
||||
triggerClassName='!h-9 !cursor-not-allowed opacity-60'
|
||||
defaultModel={{
|
||||
provider: currentDataset.embedding_model_provider,
|
||||
model: currentDataset.embedding_model,
|
||||
}}
|
||||
modelList={embeddingModelList}
|
||||
/>
|
||||
<div className='mt-2 w-full text-xs leading-6 text-gray-500'>
|
||||
{t('datasetSettings.form.embeddingModelTip')}
|
||||
<span className='text-[#155eef] cursor-pointer' onClick={() => setShowAccountSettingModal({ payload: 'provider' })}>{t('datasetSettings.form.embeddingModelTipLink')}</span>
|
||||
|
Reference in New Issue
Block a user