diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py
index 24f1020c1..fcd8ed188 100644
--- a/api/controllers/console/app/annotation.py
+++ b/api/controllers/console/app/annotation.py
@@ -89,7 +89,7 @@ class AnnotationReplyActionStatusApi(Resource):
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
- raise ValueError("The job is not exist.")
+ raise ValueError("The job does not exist.")
job_status = cache_result.decode()
error_msg = ""
@@ -226,7 +226,7 @@ class AnnotationBatchImportStatusApi(Resource):
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
cache_result = redis_client.get(indexing_cache_key)
if cache_result is None:
- raise ValueError("The job is not exist.")
+ raise ValueError("The job does not exist.")
job_status = cache_result.decode()
error_msg = ""
if job_status == "error":
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index 1b38d0776..696aaa94d 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -398,7 +398,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
indexing_cache_key = "segment_batch_import_{}".format(job_id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is None:
- raise ValueError("The job is not exist.")
+ raise ValueError("The job does not exist.")
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py
index 8d470b899..f087243a2 100644
--- a/api/controllers/service_api/dataset/dataset.py
+++ b/api/controllers/service_api/dataset/dataset.py
@@ -13,6 +13,7 @@ from fields.dataset_fields import dataset_detail_fields
from libs.login import current_user
from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
+from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
def _validate_name(name):
@@ -120,8 +121,11 @@ class DatasetListApi(DatasetApiResource):
nullable=True,
required=False,
)
- args = parser.parse_args()
+ parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
+ parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
+ parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
+ args = parser.parse_args()
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
@@ -133,6 +137,9 @@ class DatasetListApi(DatasetApiResource):
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
external_knowledge_id=args["external_knowledge_id"],
+ embedding_model_provider=args["embedding_model_provider"],
+ embedding_model_name=args["embedding_model"],
+            retrieval_model=RetrievalModel(**args["retrieval_model"]) if args["retrieval_model"] else None,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index 4cc92847f..eec6afc9e 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -49,7 +49,9 @@ class DocumentAddByTextApi(DatasetApiResource):
parser.add_argument(
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
)
- parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
+ parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
+ parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
+ parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
dataset_id = str(dataset_id)
@@ -57,7 +59,7 @@ class DocumentAddByTextApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
- raise ValueError("Dataset is not exist.")
+ raise ValueError("Dataset does not exist.")
if not dataset.indexing_technique and not args["indexing_technique"]:
raise ValueError("indexing_technique is required.")
@@ -114,7 +116,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
- raise ValueError("Dataset is not exist.")
+ raise ValueError("Dataset does not exist.")
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
@@ -172,7 +174,7 @@ class DocumentAddByFileApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
- raise ValueError("Dataset is not exist.")
+ raise ValueError("Dataset does not exist.")
if not dataset.indexing_technique and not args.get("indexing_technique"):
raise ValueError("indexing_technique is required.")
@@ -239,7 +241,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
- raise ValueError("Dataset is not exist.")
+ raise ValueError("Dataset does not exist.")
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
@@ -303,7 +305,7 @@ class DocumentDeleteApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
- raise ValueError("Dataset is not exist.")
+ raise ValueError("Dataset does not exist.")
document = DocumentService.get_document(dataset.id, document_id)
diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
index 4efd90667..1e040f415 100644
--- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
+++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
@@ -444,7 +444,7 @@ class QdrantVectorFactory(AbstractVectorFactory):
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
- raise ValueError("Dataset Collection Bindings is not exist!")
+                raise ValueError("Dataset collection binding does not exist!")
else:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index b019cf6b6..0301c8a58 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -169,6 +169,9 @@ class DatasetService:
provider: str = "vendor",
external_knowledge_api_id: Optional[str] = None,
external_knowledge_id: Optional[str] = None,
+ embedding_model_provider: Optional[str] = None,
+ embedding_model_name: Optional[str] = None,
+ retrieval_model: Optional[RetrievalModel] = None,
):
# check if dataset name already exists
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
@@ -176,9 +179,30 @@ class DatasetService:
embedding_model = None
if indexing_technique == "high_quality":
model_manager = ModelManager()
- embedding_model = model_manager.get_default_model_instance(
- tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
- )
+ if embedding_model_provider and embedding_model_name:
+ # check if embedding model setting is valid
+ DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model_name)
+ embedding_model = model_manager.get_model_instance(
+ tenant_id=tenant_id,
+ provider=embedding_model_provider,
+ model_type=ModelType.TEXT_EMBEDDING,
+ model=embedding_model_name,
+ )
+ else:
+ embedding_model = model_manager.get_default_model_instance(
+ tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
+ )
+ if retrieval_model and retrieval_model.reranking_model:
+ if (
+ retrieval_model.reranking_model.reranking_provider_name
+ and retrieval_model.reranking_model.reranking_model_name
+ ):
+                # NOTE(review): validates the rerank model via check_embedding_model_setting — confirm this helper accepts rerank-type models rather than checking against TEXT_EMBEDDING
+ DatasetService.check_embedding_model_setting(
+ tenant_id,
+ retrieval_model.reranking_model.reranking_provider_name,
+ retrieval_model.reranking_model.reranking_model_name,
+ )
dataset = Dataset(name=name, indexing_technique=indexing_technique)
# dataset = Dataset(name=name, provider=provider, config=config)
dataset.description = description
@@ -187,6 +211,7 @@ class DatasetService:
dataset.tenant_id = tenant_id
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
dataset.embedding_model = embedding_model.model if embedding_model else None
+ dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
dataset.provider = provider
db.session.add(dataset)
@@ -923,11 +948,11 @@ class DocumentService:
"score_threshold_enabled": False,
}
- dataset.retrieval_model = (
- knowledge_config.retrieval_model.model_dump()
- if knowledge_config.retrieval_model
- else default_retrieval_model
- ) # type: ignore
+ dataset.retrieval_model = (
+ knowledge_config.retrieval_model.model_dump()
+ if knowledge_config.retrieval_model
+ else default_retrieval_model
+ ) # type: ignore
documents = []
if knowledge_config.original_document_id:
diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx
index 4faf26058..ab05f0c1f 100644
--- a/web/app/(commonLayout)/datasets/template/template.en.mdx
+++ b/web/app/(commonLayout)/datasets/template/template.en.mdx
@@ -314,6 +314,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
high_quality
High quality
- economy
Economy
search_method
(string) Search method
+ - hybrid_search
Hybrid search
+ - semantic_search
Semantic search
+ - full_text_search
Full-text search
+ - reranking_enable
(bool) Whether to enable reranking
+ - reranking_model
(object) Rerank model configuration
+ - reranking_provider_name
(string) Rerank model provider
+ - reranking_model_name
(string) Rerank model name
+ - top_k
(int) Number of results to return
+ - score_threshold_enabled
(bool) Whether to enable score threshold
+ - score_threshold
(float) Score threshold
+
search_method
(文字列) 検索方法
+ - hybrid_search
ハイブリッド検索
+ - semantic_search
セマンティック検索
+ - full_text_search
全文検索
+ - reranking_enable
(ブール値) リランキングを有効にするかどうか
+ - reranking_model
(オブジェクト) リランクモデルの設定
+ - reranking_provider_name
(文字列) リランクモデルのプロバイダ
+ - reranking_model_name
(文字列) リランクモデル名
+ - top_k
(整数) 返される結果の数
+ - score_threshold_enabled
(ブール値) スコア閾値を有効にするかどうか
+ - score_threshold
(浮動小数点数) スコア閾値
+ search_method
(string) 检索方法
+ - hybrid_search
混合检索
+ - semantic_search
语义检索
+ - full_text_search
全文检索
+ - reranking_enable
(bool) 是否开启rerank
+ - reranking_model
(object) Rerank 模型配置
+ - reranking_provider_name
(string) Rerank 模型的提供商
+ - reranking_model_name
(string) Rerank 模型的名称
+ - top_k
(int) 召回条数
+ - score_threshold_enabled
(bool) 是否开启召回分数限制
+ - score_threshold
(float) 召回分数限制
+