fix(api): Some params were ignored when creating empty Datasets through API (#17932)

This commit is contained in:
Jasonfish
2025-04-14 10:24:01 +08:00
committed by GitHub
parent 4aecc9f090
commit 1f722cde22
9 changed files with 115 additions and 20 deletions

View File

@@ -169,6 +169,9 @@ class DatasetService:
provider: str = "vendor",
external_knowledge_api_id: Optional[str] = None,
external_knowledge_id: Optional[str] = None,
embedding_model_provider: Optional[str] = None,
embedding_model_name: Optional[str] = None,
retrieval_model: Optional[RetrievalModel] = None,
):
# check if dataset name already exists
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
@@ -176,9 +179,30 @@ class DatasetService:
embedding_model = None
if indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
if embedding_model_provider and embedding_model_name:
# check if embedding model setting is valid
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model_name)
embedding_model = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=embedding_model_name,
)
else:
embedding_model = model_manager.get_default_model_instance(
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
if retrieval_model and retrieval_model.reranking_model:
if (
retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
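# check if reranking model setting is valid
# NOTE(review): this reuses check_embedding_model_setting for a reranking model; confirm that helper accepts rerank providers/models, or swap in a reranking-specific check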
DatasetService.check_embedding_model_setting(
tenant_id,
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
dataset = Dataset(name=name, indexing_technique=indexing_technique)
# dataset = Dataset(name=name, provider=provider, config=config)
dataset.description = description
@@ -187,6 +211,7 @@ class DatasetService:
dataset.tenant_id = tenant_id
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
dataset.embedding_model = embedding_model.model if embedding_model else None
dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
dataset.provider = provider
db.session.add(dataset)
@@ -923,11 +948,11 @@ class DocumentService:
"score_threshold_enabled": False,
}
dataset.retrieval_model = (
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
) # type: ignore
dataset.retrieval_model = (
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
) # type: ignore
documents = []
if knowledge_config.original_document_id: