fix:hard-coded top-k fallback issue. (#24879)
This commit is contained in:
@@ -24,7 +24,7 @@ default_retrieval_model = {
|
|||||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||||
"reranking_enable": False,
|
"reranking_enable": False,
|
||||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||||
"top_k": 2,
|
"top_k": 4,
|
||||||
"score_threshold_enabled": False,
|
"score_threshold_enabled": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -304,7 +304,7 @@ class CouchbaseVector(BaseVector):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
top_k = kwargs.get("top_k", 2)
|
top_k = kwargs.get("top_k", 4)
|
||||||
try:
|
try:
|
||||||
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
|
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
|
||||||
search_iter = self._scope.search(
|
search_iter = self._scope.search(
|
||||||
|
@@ -65,7 +65,7 @@ default_retrieval_model: dict[str, Any] = {
|
|||||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||||
"reranking_enable": False,
|
"reranking_enable": False,
|
||||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||||
"top_k": 2,
|
"top_k": 4,
|
||||||
"score_threshold_enabled": False,
|
"score_threshold_enabled": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -647,7 +647,7 @@ class DatasetRetrieval:
|
|||||||
retrieval_method=retrieval_model["search_method"],
|
retrieval_method=retrieval_model["search_method"],
|
||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
query=query,
|
query=query,
|
||||||
top_k=retrieval_model.get("top_k") or 2,
|
top_k=retrieval_model.get("top_k") or 4,
|
||||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||||
if retrieval_model["score_threshold_enabled"]
|
if retrieval_model["score_threshold_enabled"]
|
||||||
else 0.0,
|
else 0.0,
|
||||||
@@ -743,7 +743,7 @@ class DatasetRetrieval:
|
|||||||
tool = DatasetMultiRetrieverTool.from_dataset(
|
tool = DatasetMultiRetrieverTool.from_dataset(
|
||||||
dataset_ids=[dataset.id for dataset in available_datasets],
|
dataset_ids=[dataset.id for dataset in available_datasets],
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
top_k=retrieve_config.top_k or 2,
|
top_k=retrieve_config.top_k or 4,
|
||||||
score_threshold=retrieve_config.score_threshold,
|
score_threshold=retrieve_config.score_threshold,
|
||||||
hit_callbacks=[hit_callback],
|
hit_callbacks=[hit_callback],
|
||||||
return_resource=return_resource,
|
return_resource=return_resource,
|
||||||
|
@@ -181,7 +181,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
retrieval_method="keyword_search",
|
retrieval_method="keyword_search",
|
||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
query=query,
|
query=query,
|
||||||
top_k=retrieval_model.get("top_k") or 2,
|
top_k=retrieval_model.get("top_k") or 4,
|
||||||
)
|
)
|
||||||
if documents:
|
if documents:
|
||||||
all_documents.extend(documents)
|
all_documents.extend(documents)
|
||||||
@@ -192,7 +192,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
retrieval_method=retrieval_model["search_method"],
|
retrieval_method=retrieval_model["search_method"],
|
||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
query=query,
|
query=query,
|
||||||
top_k=retrieval_model.get("top_k") or 2,
|
top_k=retrieval_model.get("top_k") or 4,
|
||||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||||
if retrieval_model["score_threshold_enabled"]
|
if retrieval_model["score_threshold_enabled"]
|
||||||
else 0.0,
|
else 0.0,
|
||||||
|
@@ -13,7 +13,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
|
|||||||
name: str = "dataset"
|
name: str = "dataset"
|
||||||
description: str = "use this to retrieve a dataset. "
|
description: str = "use this to retrieve a dataset. "
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
top_k: int = 2
|
top_k: int = 4
|
||||||
score_threshold: Optional[float] = None
|
score_threshold: Optional[float] = None
|
||||||
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
|
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
|
||||||
return_resource: bool
|
return_resource: bool
|
||||||
|
@@ -78,7 +78,7 @@ default_retrieval_model = {
|
|||||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||||
"reranking_enable": False,
|
"reranking_enable": False,
|
||||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||||
"top_k": 2,
|
"top_k": 4,
|
||||||
"score_threshold_enabled": False,
|
"score_threshold_enabled": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1149,7 +1149,7 @@ class DocumentService:
|
|||||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||||
"reranking_enable": False,
|
"reranking_enable": False,
|
||||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||||
"top_k": 2,
|
"top_k": 4,
|
||||||
"score_threshold_enabled": False,
|
"score_threshold_enabled": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1612,7 +1612,7 @@ class DocumentService:
|
|||||||
search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
|
search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||||
reranking_enable=False,
|
reranking_enable=False,
|
||||||
reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
|
reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
|
||||||
top_k=2,
|
top_k=4,
|
||||||
score_threshold_enabled=False,
|
score_threshold_enabled=False,
|
||||||
)
|
)
|
||||||
# save dataset
|
# save dataset
|
||||||
|
@@ -18,7 +18,7 @@ default_retrieval_model = {
|
|||||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||||
"reranking_enable": False,
|
"reranking_enable": False,
|
||||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||||
"top_k": 2,
|
"top_k": 4,
|
||||||
"score_threshold_enabled": False,
|
"score_threshold_enabled": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,7 +66,7 @@ class HitTestingService:
|
|||||||
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
|
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
|
||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
query=query,
|
query=query,
|
||||||
top_k=retrieval_model.get("top_k", 2),
|
top_k=retrieval_model.get("top_k", 4),
|
||||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||||
if retrieval_model["score_threshold_enabled"]
|
if retrieval_model["score_threshold_enabled"]
|
||||||
else 0.0,
|
else 0.0,
|
||||||
|
@@ -28,7 +28,7 @@ const ExternalKnowledgeBaseCreate: React.FC<ExternalKnowledgeBaseCreateProps> =
|
|||||||
external_knowledge_api_id: '',
|
external_knowledge_api_id: '',
|
||||||
external_knowledge_id: '',
|
external_knowledge_id: '',
|
||||||
external_retrieval_model: {
|
external_retrieval_model: {
|
||||||
top_k: 2,
|
top_k: 4,
|
||||||
score_threshold: 0.5,
|
score_threshold: 0.5,
|
||||||
score_threshold_enabled: false,
|
score_threshold_enabled: false,
|
||||||
},
|
},
|
||||||
|
@@ -49,7 +49,7 @@ const TextAreaWithButton = ({
|
|||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const [isSettingsOpen, setIsSettingsOpen] = useState(false)
|
const [isSettingsOpen, setIsSettingsOpen] = useState(false)
|
||||||
const [externalRetrievalSettings, setExternalRetrievalSettings] = useState({
|
const [externalRetrievalSettings, setExternalRetrievalSettings] = useState({
|
||||||
top_k: 2,
|
top_k: 4,
|
||||||
score_threshold: 0.5,
|
score_threshold: 0.5,
|
||||||
score_threshold_enabled: false,
|
score_threshold_enabled: false,
|
||||||
})
|
})
|
||||||
|
@@ -233,7 +233,7 @@ const DebugConfigurationContext = createContext<IDebugConfiguration>({
|
|||||||
reranking_provider_name: '',
|
reranking_provider_name: '',
|
||||||
reranking_model_name: '',
|
reranking_model_name: '',
|
||||||
},
|
},
|
||||||
top_k: 2,
|
top_k: 4,
|
||||||
score_threshold_enabled: false,
|
score_threshold_enabled: false,
|
||||||
score_threshold: 0.7,
|
score_threshold: 0.7,
|
||||||
datasets: {
|
datasets: {
|
||||||
|
Reference in New Issue
Block a user