Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: Garfield Dai <dai.hai@foxmail.com> Co-authored-by: chenhe <guchenhe@gmail.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
@@ -5,7 +5,6 @@ import type {
|
||||
BlockStatus,
|
||||
ChatPromptConfig,
|
||||
CitationConfig,
|
||||
CompletionParams,
|
||||
CompletionPromptConfig,
|
||||
ConversationHistoriesRole,
|
||||
DatasetConfigs,
|
||||
@@ -23,6 +22,7 @@ import type { DataSet } from '@/models/datasets'
|
||||
import type { VisionSettings } from '@/types/app'
|
||||
import { ModelModeType, RETRIEVE_TYPE, Resolution, TransferMethod } from '@/types/app'
|
||||
import { ANNOTATION_DEFAULT, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config'
|
||||
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
|
||||
type IDebugConfiguration = {
|
||||
appId: string
|
||||
@@ -72,8 +72,8 @@ type IDebugConfiguration = {
|
||||
query: string // user question
|
||||
setQuery: (query: string) => void
|
||||
// Belows are draft infos
|
||||
completionParams: CompletionParams
|
||||
setCompletionParams: (completionParams: CompletionParams) => void
|
||||
completionParams: FormValue
|
||||
setCompletionParams: (completionParams: FormValue) => void
|
||||
// model_config
|
||||
modelConfig: ModelConfig
|
||||
setModelConfig: (modelConfig: ModelConfig) => void
|
||||
|
@@ -1,7 +1,7 @@
|
||||
'use client'
|
||||
|
||||
import type { Dispatch, SetStateAction } from 'react'
|
||||
import { useState } from 'react'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { createContext, useContext } from 'use-context-selector'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import AccountSetting from '@/app/components/header/account-setting'
|
||||
@@ -9,6 +9,12 @@ import ApiBasedExtensionModal from '@/app/components/header/account-setting/api-
|
||||
import ModerationSettingModal from '@/app/components/app/configuration/toolbox/moderation/moderation-setting-modal'
|
||||
import ExternalDataToolModal from '@/app/components/app/configuration/tools/external-data-tool-modal'
|
||||
import AnnotationFullModal from '@/app/components/billing/annotation-full/modal'
|
||||
import ModelModal from '@/app/components/header/account-setting/model-provider-page/model-modal'
|
||||
import type {
|
||||
ConfigurateMethodEnum,
|
||||
CustomConfigrationModelFixedFields,
|
||||
ModelProvider,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
|
||||
import Pricing from '@/app/components/billing/pricing'
|
||||
import type { ModerationConfig } from '@/models/debug'
|
||||
@@ -24,6 +30,11 @@ export type ModalState<T> = {
|
||||
onValidateBeforeSaveCallback?: (newPayload: T) => boolean
|
||||
}
|
||||
|
||||
export type ModelModalType = {
|
||||
currentProvider: ModelProvider
|
||||
currentConfigurateMethod: ConfigurateMethodEnum
|
||||
currentCustomConfigrationModelFixedFields?: CustomConfigrationModelFixedFields
|
||||
}
|
||||
const ModalContext = createContext<{
|
||||
setShowAccountSettingModal: Dispatch<SetStateAction<ModalState<string> | null>>
|
||||
setShowApiBasedExtensionModal: Dispatch<SetStateAction<ModalState<ApiBasedExtension> | null>>
|
||||
@@ -31,6 +42,7 @@ const ModalContext = createContext<{
|
||||
setShowExternalDataToolModal: Dispatch<SetStateAction<ModalState<ExternalDataTool> | null>>
|
||||
setShowPricingModal: Dispatch<SetStateAction<any>>
|
||||
setShowAnnotationFullModal: () => void
|
||||
setShowModelModal: Dispatch<SetStateAction<ModalState<ModelModalType> | null>>
|
||||
}>({
|
||||
setShowAccountSettingModal: () => { },
|
||||
setShowApiBasedExtensionModal: () => { },
|
||||
@@ -38,6 +50,7 @@ const ModalContext = createContext<{
|
||||
setShowExternalDataToolModal: () => { },
|
||||
setShowPricingModal: () => { },
|
||||
setShowAnnotationFullModal: () => { },
|
||||
setShowModelModal: () => {},
|
||||
})
|
||||
|
||||
export const useModalContext = () => useContext(ModalContext)
|
||||
@@ -52,6 +65,7 @@ export const ModalContextProvider = ({
|
||||
const [showApiBasedExtensionModal, setShowApiBasedExtensionModal] = useState<ModalState<ApiBasedExtension> | null>(null)
|
||||
const [showModerationSettingModal, setShowModerationSettingModal] = useState<ModalState<ModerationConfig> | null>(null)
|
||||
const [showExternalDataToolModal, setShowExternalDataToolModal] = useState<ModalState<ExternalDataTool> | null>(null)
|
||||
const [showModelModal, setShowModelModal] = useState<ModalState<ModelModalType> | null>(null)
|
||||
const searchParams = useSearchParams()
|
||||
const router = useRouter()
|
||||
const [showPricingModal, setShowPricingModal] = useState(searchParams.get('show-pricing') === '1')
|
||||
@@ -70,6 +84,20 @@ export const ModalContextProvider = ({
|
||||
showModerationSettingModal.onCancelCallback()
|
||||
}
|
||||
|
||||
const handleCancelModelModal = useCallback(() => {
|
||||
setShowModelModal(null)
|
||||
|
||||
if (showModelModal?.onCancelCallback)
|
||||
showModelModal.onCancelCallback()
|
||||
}, [showModelModal])
|
||||
|
||||
const handleSaveModelModal = useCallback(() => {
|
||||
if (showModelModal?.onSaveCallback)
|
||||
showModelModal.onSaveCallback(showModelModal.payload)
|
||||
|
||||
setShowModelModal(null)
|
||||
}, [showModelModal])
|
||||
|
||||
const handleSaveApiBasedExtension = (newApiBasedExtension: ApiBasedExtension) => {
|
||||
if (showApiBasedExtensionModal?.onSaveCallback)
|
||||
showApiBasedExtensionModal.onSaveCallback(newApiBasedExtension)
|
||||
@@ -106,6 +134,7 @@ export const ModalContextProvider = ({
|
||||
setShowExternalDataToolModal,
|
||||
setShowPricingModal: () => setShowPricingModal(true),
|
||||
setShowAnnotationFullModal: () => setShowAnnotationFullModal(true),
|
||||
setShowModelModal,
|
||||
}}>
|
||||
<>
|
||||
{children}
|
||||
@@ -165,6 +194,17 @@ export const ModalContextProvider = ({
|
||||
onHide={() => setShowAnnotationFullModal(false)} />
|
||||
)
|
||||
}
|
||||
{
|
||||
!!showModelModal && (
|
||||
<ModelModal
|
||||
provider={showModelModal.payload.currentProvider}
|
||||
configurateMethod={showModelModal.payload.currentConfigurateMethod}
|
||||
currentCustomConfigrationModelFixedFields={showModelModal.payload.currentCustomConfigrationModelFixedFields}
|
||||
onCancel={handleCancelModelModal}
|
||||
onSave={handleSaveModelModal}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</>
|
||||
</ModalContext.Provider>
|
||||
)
|
||||
|
@@ -2,10 +2,18 @@
|
||||
|
||||
import { createContext, useContext } from 'use-context-selector'
|
||||
import useSWR from 'swr'
|
||||
import { useEffect, useState } from 'react'
|
||||
import { fetchDefaultModal, fetchModelList, fetchSupportRetrievalMethods } from '@/service/common'
|
||||
import { ModelFeature, ModelType } from '@/app/components/header/account-setting/model-page/declarations'
|
||||
import type { BackendModel } from '@/app/components/header/account-setting/model-page/declarations'
|
||||
import { useEffect, useMemo, useState } from 'react'
|
||||
import {
|
||||
fetchModelList,
|
||||
fetchModelProviders,
|
||||
fetchSupportRetrievalMethods,
|
||||
} from '@/service/common'
|
||||
import {
|
||||
ModelFeatureEnum,
|
||||
ModelStatusEnum,
|
||||
ModelTypeEnum,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { Model, ModelProvider } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { RETRIEVE_METHOD } from '@/types/app'
|
||||
import { Plan, type UsagePlanInfo } from '@/app/components/billing/type'
|
||||
import { fetchCurrentPlanInfo } from '@/service/billing'
|
||||
@@ -13,23 +21,11 @@ import { parseCurrentPlan } from '@/app/components/billing/utils'
|
||||
import { defaultPlan } from '@/app/components/billing/config'
|
||||
|
||||
const ProviderContext = createContext<{
|
||||
textGenerationModelList: BackendModel[]
|
||||
embeddingsModelList: BackendModel[]
|
||||
speech2textModelList: BackendModel[]
|
||||
rerankModelList: BackendModel[]
|
||||
agentThoughtModelList: BackendModel[]
|
||||
updateModelList: (type: ModelType) => void
|
||||
textGenerationDefaultModel?: BackendModel
|
||||
mutateTextGenerationDefaultModel: () => void
|
||||
embeddingsDefaultModel?: BackendModel
|
||||
isEmbeddingsDefaultModelValid: boolean
|
||||
mutateEmbeddingsDefaultModel: () => void
|
||||
speech2textDefaultModel?: BackendModel
|
||||
mutateSpeech2textDefaultModel: () => void
|
||||
rerankDefaultModel?: BackendModel
|
||||
isRerankDefaultModelVaild: boolean
|
||||
mutateRerankDefaultModel: () => void
|
||||
modelProviders: ModelProvider[]
|
||||
textGenerationModelList: Model[]
|
||||
agentThoughtModelList: Model[]
|
||||
supportRetrievalMethods: RETRIEVE_METHOD[]
|
||||
hasSettedApiKey: boolean
|
||||
plan: {
|
||||
type: Plan
|
||||
usage: UsagePlanInfo
|
||||
@@ -39,42 +35,30 @@ const ProviderContext = createContext<{
|
||||
enableBilling: boolean
|
||||
enableReplaceWebAppLogo: boolean
|
||||
}>({
|
||||
textGenerationModelList: [],
|
||||
embeddingsModelList: [],
|
||||
speech2textModelList: [],
|
||||
rerankModelList: [],
|
||||
agentThoughtModelList: [],
|
||||
updateModelList: () => { },
|
||||
textGenerationDefaultModel: undefined,
|
||||
mutateTextGenerationDefaultModel: () => { },
|
||||
speech2textDefaultModel: undefined,
|
||||
mutateSpeech2textDefaultModel: () => { },
|
||||
embeddingsDefaultModel: undefined,
|
||||
isEmbeddingsDefaultModelValid: false,
|
||||
mutateEmbeddingsDefaultModel: () => { },
|
||||
rerankDefaultModel: undefined,
|
||||
isRerankDefaultModelVaild: false,
|
||||
mutateRerankDefaultModel: () => { },
|
||||
supportRetrievalMethods: [],
|
||||
plan: {
|
||||
type: Plan.sandbox,
|
||||
usage: {
|
||||
vectorSpace: 32,
|
||||
buildApps: 12,
|
||||
teamMembers: 1,
|
||||
annotatedResponse: 1,
|
||||
},
|
||||
total: {
|
||||
vectorSpace: 200,
|
||||
buildApps: 50,
|
||||
teamMembers: 1,
|
||||
annotatedResponse: 10,
|
||||
},
|
||||
},
|
||||
isFetchedPlan: false,
|
||||
enableBilling: false,
|
||||
enableReplaceWebAppLogo: false,
|
||||
})
|
||||
modelProviders: [],
|
||||
textGenerationModelList: [],
|
||||
agentThoughtModelList: [],
|
||||
supportRetrievalMethods: [],
|
||||
hasSettedApiKey: true,
|
||||
plan: {
|
||||
type: Plan.sandbox,
|
||||
usage: {
|
||||
vectorSpace: 32,
|
||||
buildApps: 12,
|
||||
teamMembers: 1,
|
||||
annotatedResponse: 1,
|
||||
},
|
||||
total: {
|
||||
vectorSpace: 200,
|
||||
buildApps: 50,
|
||||
teamMembers: 1,
|
||||
annotatedResponse: 10,
|
||||
},
|
||||
},
|
||||
isFetchedPlan: false,
|
||||
enableBilling: false,
|
||||
enableReplaceWebAppLogo: false,
|
||||
})
|
||||
|
||||
export const useProviderContext = () => useContext(ProviderContext)
|
||||
|
||||
@@ -84,39 +68,30 @@ type ProviderContextProviderProps = {
|
||||
export const ProviderContextProvider = ({
|
||||
children,
|
||||
}: ProviderContextProviderProps) => {
|
||||
const { data: textGenerationDefaultModel, mutate: mutateTextGenerationDefaultModel } = useSWR('/workspaces/current/default-model?model_type=text-generation', fetchDefaultModal)
|
||||
const { data: embeddingsDefaultModel, mutate: mutateEmbeddingsDefaultModel } = useSWR('/workspaces/current/default-model?model_type=embeddings', fetchDefaultModal)
|
||||
const { data: speech2textDefaultModel, mutate: mutateSpeech2textDefaultModel } = useSWR('/workspaces/current/default-model?model_type=speech2text', fetchDefaultModal)
|
||||
const { data: rerankDefaultModel, mutate: mutateRerankDefaultModel } = useSWR('/workspaces/current/default-model?model_type=reranking', fetchDefaultModal)
|
||||
const fetchModelListUrlPrefix = '/workspaces/current/models/model-type/'
|
||||
const { data: textGenerationModelList, mutate: mutateTextGenerationModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.textGeneration}`, fetchModelList)
|
||||
const { data: embeddingsModelList, mutate: mutateEmbeddingsModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.embeddings}`, fetchModelList)
|
||||
const { data: speech2textModelList, mutate: mutateSpeech2textModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.speech2text}`, fetchModelList)
|
||||
const { data: rerankModelList, mutate: mutateRerankModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.reranking}`, fetchModelList)
|
||||
const { data: providersData } = useSWR('/workspaces/current/model-providers', fetchModelProviders)
|
||||
const fetchModelListUrlPrefix = '/workspaces/current/models/model-types/'
|
||||
const { data: textGenerationModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelTypeEnum.textGeneration}`, fetchModelList)
|
||||
const { data: supportRetrievalMethods } = useSWR('/datasets/retrieval-setting', fetchSupportRetrievalMethods)
|
||||
|
||||
const agentThoughtModelList = textGenerationModelList?.filter((item) => {
|
||||
return item.features?.includes(ModelFeature.agentThought)
|
||||
})
|
||||
const agentThoughtModelList = useMemo(() => {
|
||||
const result: Model[] = []
|
||||
if (textGenerationModelList?.data) {
|
||||
textGenerationModelList?.data.forEach((item) => {
|
||||
const agentThoughtModels = item.models.filter(model => model.features?.includes(ModelFeatureEnum.agentThought))
|
||||
|
||||
const isRerankDefaultModelVaild = !!rerankModelList?.find(
|
||||
item => item.model_name === rerankDefaultModel?.model_name && item.model_provider.provider_name === rerankDefaultModel?.model_provider.provider_name,
|
||||
)
|
||||
if (agentThoughtModels.length) {
|
||||
result.push({
|
||||
...item,
|
||||
models: agentThoughtModels,
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
const isEmbeddingsDefaultModelValid = !!embeddingsModelList?.find(
|
||||
item => item.model_name === embeddingsDefaultModel?.model_name && item.model_provider.provider_name === embeddingsDefaultModel?.model_provider.provider_name,
|
||||
)
|
||||
return result
|
||||
}
|
||||
|
||||
const updateModelList = (type: ModelType) => {
|
||||
if (type === ModelType.textGeneration)
|
||||
mutateTextGenerationModelList()
|
||||
if (type === ModelType.embeddings)
|
||||
mutateEmbeddingsModelList()
|
||||
if (type === ModelType.speech2text)
|
||||
mutateSpeech2textModelList()
|
||||
if (type === ModelType.reranking)
|
||||
mutateRerankModelList()
|
||||
}
|
||||
return []
|
||||
}, [textGenerationModelList])
|
||||
|
||||
const [plan, setPlan] = useState(defaultPlan)
|
||||
const [isFetchedPlan, setIsFetchedPlan] = useState(false)
|
||||
@@ -144,22 +119,10 @@ export const ProviderContextProvider = ({
|
||||
|
||||
return (
|
||||
<ProviderContext.Provider value={{
|
||||
textGenerationModelList: textGenerationModelList || [],
|
||||
embeddingsModelList: embeddingsModelList || [],
|
||||
speech2textModelList: speech2textModelList || [],
|
||||
rerankModelList: rerankModelList || [],
|
||||
agentThoughtModelList: agentThoughtModelList || [],
|
||||
updateModelList,
|
||||
textGenerationDefaultModel,
|
||||
mutateTextGenerationDefaultModel,
|
||||
embeddingsDefaultModel,
|
||||
mutateEmbeddingsDefaultModel,
|
||||
speech2textDefaultModel,
|
||||
mutateSpeech2textDefaultModel,
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelVaild,
|
||||
isEmbeddingsDefaultModelValid,
|
||||
mutateRerankDefaultModel,
|
||||
modelProviders: providersData?.data || [],
|
||||
textGenerationModelList: textGenerationModelList?.data || [],
|
||||
agentThoughtModelList,
|
||||
hasSettedApiKey: !!textGenerationModelList?.data.some(model => model.status === ModelStatusEnum.active),
|
||||
supportRetrievalMethods: supportRetrievalMethods?.retrieval_method || [],
|
||||
plan,
|
||||
isFetchedPlan,
|
||||
|
Reference in New Issue
Block a user