Model Runtime (#1858)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
takatost
2024-01-02 23:42:00 +08:00
committed by GitHub
parent e91dd28a76
commit d069c668f8
807 changed files with 171310 additions and 23806 deletions

View File

@@ -1,5 +1,8 @@
import type { BackendModel } from '../../header/account-setting/model-page/declarations'
import { RETRIEVE_METHOD, type RetrievalConfig } from '@/types/app'
import type {
DefaultModelResponse,
Model,
} from '@/app/components/header/account-setting/model-provider-page/declarations'
export const isReRankModelSelected = ({
rerankDefaultModel,
@@ -8,15 +11,18 @@ export const isReRankModelSelected = ({
rerankModelList,
indexMethod,
}: {
rerankDefaultModel?: BackendModel
rerankDefaultModel?: DefaultModelResponse
isRerankDefaultModelVaild: boolean
retrievalConfig: RetrievalConfig
rerankModelList: BackendModel[]
rerankModelList: Model[]
indexMethod?: string
}) => {
const rerankModelSelected = (() => {
if (retrievalConfig.reranking_model?.reranking_model_name)
return !!rerankModelList.find(({ model_name }) => model_name === retrievalConfig.reranking_model?.reranking_model_name)
if (retrievalConfig.reranking_model?.reranking_model_name) {
const provider = rerankModelList.find(({ provider }) => provider === retrievalConfig.reranking_model?.reranking_provider_name)
return provider?.models.find(({ model }) => model === retrievalConfig.reranking_model?.reranking_model_name)
}
if (isRerankDefaultModelVaild)
return !!rerankDefaultModel
@@ -39,7 +45,7 @@ export const ensureRerankModelSelected = ({
indexMethod,
retrievalConfig,
}: {
rerankDefaultModel: BackendModel
rerankDefaultModel: DefaultModelResponse
retrievalConfig: RetrievalConfig
indexMethod?: string
}) => {
@@ -52,8 +58,8 @@ export const ensureRerankModelSelected = ({
return {
...retrievalConfig,
reranking_model: {
reranking_provider_name: rerankDefaultModel.model_provider.provider_name,
reranking_model_name: rerankDefaultModel.model_name,
reranking_provider_name: rerankDefaultModel.provider.provider,
reranking_model_name: rerankDefaultModel.model,
},
}
}

View File

@@ -9,6 +9,7 @@ import RadioCard from '@/app/components/base/radio-card'
import { PatternRecognition, Semantic } from '@/app/components/base/icons/src/vender/solid/development'
import { FileSearch02 } from '@/app/components/base/icons/src/vender/solid/files'
import { useProviderContext } from '@/context/provider-context'
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
type Props = {
value: RetrievalConfig
@@ -20,14 +21,15 @@ const RetrievalMethodConfig: FC<Props> = ({
onChange,
}) => {
const { t } = useTranslation()
const { supportRetrievalMethods, rerankDefaultModel } = useProviderContext()
const { supportRetrievalMethods } = useProviderContext()
const { data: rerankDefaultModel } = useDefaultModel(3)
const value = (() => {
if (!passValue.reranking_model.reranking_model_name) {
return {
...passValue,
reranking_model: {
reranking_provider_name: rerankDefaultModel?.model_provider.provider_name || '',
reranking_model_name: rerankDefaultModel?.model_name || '',
reranking_provider_name: rerankDefaultModel?.provider.provider || '',
reranking_model_name: rerankDefaultModel?.model || '',
},
}
}

View File

@@ -9,10 +9,9 @@ import { RETRIEVE_METHOD } from '@/types/app'
import Switch from '@/app/components/base/switch'
import Tooltip from '@/app/components/base/tooltip-plus'
import { HelpCircle } from '@/app/components/base/icons/src/vender/line/general'
import ModelSelector from '@/app/components/header/account-setting/model-page/model-selector'
import { ModelType } from '@/app/components/header/account-setting/model-page/declarations'
import type { RetrievalConfig } from '@/types/app'
import { useProviderContext } from '@/context/provider-context'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
import { useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
type Props = {
type: RETRIEVE_METHOD
@@ -29,8 +28,9 @@ const RetrievalParamConfig: FC<Props> = ({
const canToggleRerankModalEnable = type !== RETRIEVE_METHOD.hybrid
const isEconomical = type === RETRIEVE_METHOD.invertedIndex
const {
rerankDefaultModel,
} = useProviderContext()
defaultModel: rerankDefaultModel,
modelList: rerankModelList,
} = useModelListAndDefaultModel(3)
const rerankModel = (() => {
if (value.reranking_model) {
@@ -41,8 +41,8 @@ const RetrievalParamConfig: FC<Props> = ({
}
else if (rerankDefaultModel) {
return {
provider_name: rerankDefaultModel.model_provider.provider_name,
model_name: rerankDefaultModel.model_name,
provider_name: rerankDefaultModel.provider.provider,
model_name: rerankDefaultModel.model,
}
}
})()
@@ -71,24 +71,21 @@ const RetrievalParamConfig: FC<Props> = ({
</Tooltip>
</div>
</div>
<div>
<ModelSelector
whenEmptyGoToSetting
popClassName='!max-w-[100%] !w-full'
value={rerankModel && { providerName: rerankModel.provider_name, modelName: rerankModel.model_name } as any}
modelType={ModelType.reranking}
readonly={!value.reranking_enable && type !== RETRIEVE_METHOD.hybrid}
onChange={(v) => {
onChange({
...value,
reranking_model: {
reranking_provider_name: v.model_provider.provider_name,
reranking_model_name: v.model_name,
},
})
}}
/>
</div>
<ModelSelector
triggerClassName={`${!value.reranking_enable && type !== RETRIEVE_METHOD.hybrid && '!opacity-60 !cursor-not-allowed'}`}
defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }}
modelList={rerankModelList}
readonly={!value.reranking_enable && type !== RETRIEVE_METHOD.hybrid}
onSelect={(v) => {
onChange({
...value,
reranking_model: {
reranking_provider_name: v.provider,
reranking_model_name: v.model,
},
})
}}
/>
</div>
)}

View File

@@ -11,8 +11,8 @@ import type { DataSet, FileItem, createDocumentResponse } from '@/models/dataset
import { fetchDataSource } from '@/service/common'
import { fetchDatasetDetail } from '@/service/datasets'
import type { NotionPage } from '@/models/common'
import { useProviderContext } from '@/context/provider-context'
import { useModalContext } from '@/context/modal-context'
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
type DatasetUpdateFormProps = {
datasetId?: string
@@ -28,7 +28,7 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => {
const [fileList, setFiles] = useState<FileItem[]>([])
const [result, setResult] = useState<createDocumentResponse | undefined>()
const [hasError, setHasError] = useState(false)
const { embeddingsDefaultModel } = useProviderContext()
const { data: embeddingsDefaultModel } = useDefaultModel(2)
const [notionPages, setNotionPages] = useState<NotionPage[]>([])
const updateNotionPages = (value: NotionPage[]) => {

View File

@@ -38,9 +38,9 @@ import { useDatasetDetailContext } from '@/context/dataset-detail'
import I18n from '@/context/i18n'
import { IS_CE_EDITION } from '@/config'
import { RETRIEVE_METHOD } from '@/types/app'
import { useProviderContext } from '@/context/provider-context'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import Tooltip from '@/app/components/base/tooltip'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
type ValueOf<T> = T[keyof T]
type StepTwoProps = {
@@ -268,10 +268,10 @@ const StepTwo = ({
}
}
const {
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext()
modelList: rerankModelList,
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelVaild,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(3)
const getCreationParams = () => {
let params
if (isSetting) {
@@ -289,7 +289,7 @@ const StepTwo = ({
if (
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelVaild,
isRerankDefaultModelVaild: !!isRerankDefaultModelVaild,
rerankModelList,
// eslint-disable-next-line @typescript-eslint/no-use-before-define
retrievalConfig,
@@ -489,8 +489,8 @@ const StepTwo = ({
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: rerankDefaultModel?.model_provider.provider_name,
reranking_model_name: rerankDefaultModel?.model_name,
reranking_provider_name: rerankDefaultModel?.provider.provider,
reranking_model_name: rerankDefaultModel?.model,
},
top_k: 3,
score_threshold_enabled: false,

View File

@@ -13,7 +13,7 @@ import Loading from '@/app/components/base/loading'
import StepTwo from '@/app/components/datasets/create/step-two'
import AccountSetting from '@/app/components/header/account-setting'
import AppUnavailable from '@/app/components/base/app-unavailable'
import { useProviderContext } from '@/context/provider-context'
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
type DocumentSettingsProps = {
datasetId: string
@@ -26,7 +26,7 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
const [isShowSetAPIKey, { setTrue: showSetAPIKey, setFalse: hideSetAPIkey }] = useBoolean()
const [hasError, setHasError] = useState(false)
const { indexingTechnique, dataset } = useContext(DatasetDetailContext)
const { embeddingsDefaultModel } = useProviderContext()
const { data: embeddingsDefaultModel } = useDefaultModel(2)
const saveHandler = () => router.push(`/datasets/${datasetId}/documents/${documentId}`)

View File

@@ -8,8 +8,8 @@ import type { RetrievalConfig } from '@/types/app'
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
import Button from '@/app/components/base/button'
import { useProviderContext } from '@/context/provider-context'
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
type Props = {
indexMethod: string
@@ -36,16 +36,16 @@ const ModifyRetrievalModal: FC<Props> = ({
// }, ref)
const {
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext()
modelList: rerankModelList,
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelVaild,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(3)
const handleSave = () => {
if (
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelVaild,
isRerankDefaultModelVaild: !!isRerankDefaultModelVaild,
rerankModelList,
retrievalConfig,
indexMethod,

View File

@@ -15,14 +15,16 @@ import { ToastContext } from '@/app/components/base/toast'
import Button from '@/app/components/base/button'
import { updateDatasetSetting } from '@/service/datasets'
import type { DataSet, DataSetListResponse } from '@/models/datasets'
import ModelSelector from '@/app/components/header/account-setting/model-page/model-selector'
import type { ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
import { ModelType } from '@/app/components/header/account-setting/model-page/declarations'
import DatasetDetailContext from '@/context/dataset-detail'
import { type RetrievalConfig } from '@/types/app'
import { useModalContext } from '@/context/modal-context'
import { useProviderContext } from '@/context/provider-context'
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
import {
useModelList,
useModelListAndDefaultModelAndCurrentProviderAndModel,
} from '@/app/components/header/account-setting/model-provider-page/hooks'
const rowClass = `
flex justify-between py-4 flex-wrap gap-y-2
`
@@ -56,11 +58,13 @@ const Form = () => {
const [permission, setPermission] = useState(currentDataset?.permission)
const [indexMethod, setIndexMethod] = useState(currentDataset?.indexing_technique)
const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict as RetrievalConfig)
const {
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext()
modelList: rerankModelList,
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelVaild,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(3)
const { data: embeddingModelList } = useModelList(2)
const handleSave = async () => {
if (loading)
@@ -72,7 +76,7 @@ const Form = () => {
if (
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelVaild,
isRerankDefaultModelVaild: !!isRerankDefaultModelVaild,
rerankModelList,
retrievalConfig,
indexMethod,
@@ -183,17 +187,15 @@ const Form = () => {
<div>{t('datasetSettings.form.embeddingModel')}</div>
</div>
<div className='w-[480px]'>
<div className='w-full h-9 rounded-lg bg-gray-100 opacity-60'>
<ModelSelector
readonly
value={{
providerName: currentDataset.embedding_model_provider as ProviderEnum,
modelName: currentDataset.embedding_model,
}}
modelType={ModelType.embeddings}
onChange={() => {}}
/>
</div>
<ModelSelector
readonly
triggerClassName='!h-9 !cursor-not-allowed opacity-60'
defaultModel={{
provider: currentDataset.embedding_model_provider,
model: currentDataset.embedding_model,
}}
modelList={embeddingModelList}
/>
<div className='mt-2 w-full text-xs leading-6 text-gray-500'>
{t('datasetSettings.form.embeddingModelTip')}
<span className='text-[#155eef] cursor-pointer' onClick={() => setShowAccountSettingModal({ payload: 'provider' })}>{t('datasetSettings.form.embeddingModelTipLink')}</span>