chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -32,14 +32,11 @@ from models.dataset import Dataset, DatasetQuery, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enabled': False
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
@@ -48,15 +45,18 @@ class DatasetRetrieval:
|
||||
self.application_generate_entity = application_generate_entity
|
||||
|
||||
def retrieve(
|
||||
self, app_id: str, user_id: str, tenant_id: str,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
config: DatasetEntity,
|
||||
query: str,
|
||||
invoke_from: InvokeFrom,
|
||||
show_retrieve_source: bool,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
message_id: str,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
self,
|
||||
app_id: str,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
config: DatasetEntity,
|
||||
query: str,
|
||||
invoke_from: InvokeFrom,
|
||||
show_retrieve_source: bool,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
message_id: str,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Retrieve dataset.
|
||||
@@ -84,16 +84,12 @@ class DatasetRetrieval:
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=model_config.provider,
|
||||
model=model_config.model
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
|
||||
)
|
||||
|
||||
# get model schema
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model=model_config.model,
|
||||
credentials=model_config.credentials
|
||||
model=model_config.model, credentials=model_config.credentials
|
||||
)
|
||||
|
||||
if not model_schema:
|
||||
@@ -102,39 +98,46 @@ class DatasetRetrieval:
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
features = model_schema.features
|
||||
if features:
|
||||
if ModelFeature.TOOL_CALL in features \
|
||||
or ModelFeature.MULTI_TOOL_CALL in features:
|
||||
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
|
||||
planning_strategy = PlanningStrategy.ROUTER
|
||||
available_datasets = []
|
||||
for dataset_id in dataset_ids:
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
continue
|
||||
|
||||
# pass if dataset is not available
|
||||
if (dataset and dataset.available_document_count == 0
|
||||
and dataset.available_document_count == 0):
|
||||
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
|
||||
continue
|
||||
|
||||
available_datasets.append(dataset)
|
||||
all_documents = []
|
||||
user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'
|
||||
user_from = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
|
||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
all_documents = self.single_retrieve(
|
||||
app_id, tenant_id, user_id, user_from, available_datasets, query,
|
||||
app_id,
|
||||
tenant_id,
|
||||
user_id,
|
||||
user_from,
|
||||
available_datasets,
|
||||
query,
|
||||
model_instance,
|
||||
model_config, planning_strategy, message_id
|
||||
model_config,
|
||||
planning_strategy,
|
||||
message_id,
|
||||
)
|
||||
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
all_documents = self.multiple_retrieve(
|
||||
app_id, tenant_id, user_id, user_from,
|
||||
available_datasets, query, retrieve_config.top_k,
|
||||
app_id,
|
||||
tenant_id,
|
||||
user_id,
|
||||
user_from,
|
||||
available_datasets,
|
||||
query,
|
||||
retrieve_config.top_k,
|
||||
retrieve_config.score_threshold,
|
||||
retrieve_config.rerank_mode,
|
||||
retrieve_config.reranking_model,
|
||||
@@ -145,89 +148,89 @@ class DatasetRetrieval:
|
||||
|
||||
document_score_list = {}
|
||||
for item in all_documents:
|
||||
if item.metadata.get('score'):
|
||||
document_score_list[item.metadata['doc_id']] = item.metadata['score']
|
||||
if item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata['doc_id'] for document in all_documents]
|
||||
index_node_ids = [document.metadata["doc_id"] for document in all_documents]
|
||||
segments = DocumentSegment.query.filter(
|
||||
DocumentSegment.dataset_id.in_(dataset_ids),
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids)
|
||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||
).all()
|
||||
|
||||
if segments:
|
||||
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
|
||||
sorted_segments = sorted(segments,
|
||||
key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
|
||||
float('inf')))
|
||||
sorted_segments = sorted(
|
||||
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
|
||||
)
|
||||
for segment in sorted_segments:
|
||||
if segment.answer:
|
||||
document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
|
||||
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
|
||||
else:
|
||||
document_context_list.append(segment.get_sign_content())
|
||||
if show_retrieve_source:
|
||||
context_list = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
dataset = Dataset.query.filter_by(
|
||||
id=segment.dataset_id
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
document = DatasetDocument.query.filter(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).first()
|
||||
document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).first()
|
||||
if dataset and document:
|
||||
source = {
|
||||
'position': resource_number,
|
||||
'dataset_id': dataset.id,
|
||||
'dataset_name': dataset.name,
|
||||
'document_id': document.id,
|
||||
'document_name': document.name,
|
||||
'data_source_type': document.data_source_type,
|
||||
'segment_id': segment.id,
|
||||
'retriever_from': invoke_from.to_source(),
|
||||
'score': document_score_list.get(segment.index_node_id, None)
|
||||
"position": resource_number,
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": invoke_from.to_source(),
|
||||
"score": document_score_list.get(segment.index_node_id, None),
|
||||
}
|
||||
|
||||
if invoke_from.to_source() == 'dev':
|
||||
source['hit_count'] = segment.hit_count
|
||||
source['word_count'] = segment.word_count
|
||||
source['segment_position'] = segment.position
|
||||
source['index_node_hash'] = segment.index_node_hash
|
||||
if invoke_from.to_source() == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
source["segment_position"] = segment.position
|
||||
source["index_node_hash"] = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source['content'] = segment.content
|
||||
source["content"] = segment.content
|
||||
context_list.append(source)
|
||||
resource_number += 1
|
||||
if hit_callback:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
return ''
|
||||
return ""
|
||||
|
||||
def single_retrieve(
|
||||
self, app_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
query: str,
|
||||
model_instance: ModelInstance,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
planning_strategy: PlanningStrategy,
|
||||
message_id: Optional[str] = None,
|
||||
self,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
query: str,
|
||||
model_instance: ModelInstance,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
planning_strategy: PlanningStrategy,
|
||||
message_id: Optional[str] = None,
|
||||
):
|
||||
tools = []
|
||||
for dataset in available_datasets:
|
||||
description = dataset.description
|
||||
if not description:
|
||||
description = 'useful for when you want to answer queries about the ' + dataset.name
|
||||
description = "useful for when you want to answer queries about the " + dataset.name
|
||||
|
||||
description = description.replace('\n', '').replace('\r', '')
|
||||
description = description.replace("\n", "").replace("\r", "")
|
||||
message_tool = PromptMessageTool(
|
||||
name=dataset.id,
|
||||
description=description,
|
||||
@@ -235,14 +238,15 @@ class DatasetRetrieval:
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
},
|
||||
)
|
||||
tools.append(message_tool)
|
||||
dataset_id = None
|
||||
if planning_strategy == PlanningStrategy.REACT_ROUTER:
|
||||
react_multi_dataset_router = ReactMultiDatasetRouter()
|
||||
dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance,
|
||||
user_id, tenant_id)
|
||||
dataset_id = react_multi_dataset_router.invoke(
|
||||
query, tools, model_config, model_instance, user_id, tenant_id
|
||||
)
|
||||
|
||||
elif planning_strategy == PlanningStrategy.ROUTER:
|
||||
function_call_router = FunctionCallMultiDatasetRouter()
|
||||
@@ -250,37 +254,37 @@ class DatasetRetrieval:
|
||||
|
||||
if dataset_id:
|
||||
# get retrieval model config
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if dataset:
|
||||
retrieval_model_config = dataset.retrieval_model \
|
||||
if dataset.retrieval_model else default_retrieval_model
|
||||
retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
|
||||
# get top k
|
||||
top_k = retrieval_model_config['top_k']
|
||||
top_k = retrieval_model_config["top_k"]
|
||||
# get retrieval method
|
||||
if dataset.indexing_technique == "economy":
|
||||
retrieval_method = 'keyword_search'
|
||||
retrieval_method = "keyword_search"
|
||||
else:
|
||||
retrieval_method = retrieval_model_config['search_method']
|
||||
retrieval_method = retrieval_model_config["search_method"]
|
||||
# get reranking model
|
||||
reranking_model = retrieval_model_config['reranking_model'] \
|
||||
if retrieval_model_config['reranking_enable'] else None
|
||||
reranking_model = (
|
||||
retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None
|
||||
)
|
||||
# get score threshold
|
||||
score_threshold = .0
|
||||
score_threshold = 0.0
|
||||
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
|
||||
if score_threshold_enabled:
|
||||
score_threshold = retrieval_model_config.get("score_threshold")
|
||||
|
||||
with measure_time() as timer:
|
||||
results = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_method, dataset_id=dataset.id,
|
||||
retrieval_method=retrieval_method,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k, score_threshold=score_threshold,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
reranking_mode=retrieval_model_config.get('reranking_mode', 'reranking_model'),
|
||||
weights=retrieval_model_config.get('weights', None),
|
||||
reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
|
||||
weights=retrieval_model_config.get("weights", None),
|
||||
)
|
||||
self._on_query(query, [dataset_id], app_id, user_from, user_id)
|
||||
|
||||
@@ -291,20 +295,20 @@ class DatasetRetrieval:
|
||||
return []
|
||||
|
||||
def multiple_retrieve(
|
||||
self,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_mode: str,
|
||||
reranking_model: Optional[dict] = None,
|
||||
weights: Optional[dict] = None,
|
||||
reranking_enable: bool = True,
|
||||
message_id: Optional[str] = None,
|
||||
self,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_mode: str,
|
||||
reranking_model: Optional[dict] = None,
|
||||
weights: Optional[dict] = None,
|
||||
reranking_enable: bool = True,
|
||||
message_id: Optional[str] = None,
|
||||
):
|
||||
threads = []
|
||||
all_documents = []
|
||||
@@ -312,13 +316,16 @@ class DatasetRetrieval:
|
||||
index_type = None
|
||||
for dataset in available_datasets:
|
||||
index_type = dataset.indexing_technique
|
||||
retrieval_thread = threading.Thread(target=self._retriever, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset.id,
|
||||
'query': query,
|
||||
'top_k': top_k,
|
||||
'all_documents': all_documents,
|
||||
})
|
||||
retrieval_thread = threading.Thread(
|
||||
target=self._retriever,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"dataset_id": dataset.id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"all_documents": all_documents,
|
||||
},
|
||||
)
|
||||
threads.append(retrieval_thread)
|
||||
retrieval_thread.start()
|
||||
for thread in threads:
|
||||
@@ -327,16 +334,10 @@ class DatasetRetrieval:
|
||||
with measure_time() as timer:
|
||||
if reranking_enable:
|
||||
# do rerank for searched documents
|
||||
data_post_processor = DataPostProcessor(
|
||||
tenant_id, reranking_mode,
|
||||
reranking_model, weights, False
|
||||
)
|
||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||
|
||||
all_documents = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k
|
||||
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
|
||||
)
|
||||
else:
|
||||
if index_type == "economy":
|
||||
@@ -357,30 +358,26 @@ class DatasetRetrieval:
|
||||
"""Handle retrieval end."""
|
||||
for document in documents:
|
||||
query = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.index_node_id == document.metadata['doc_id']
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
)
|
||||
|
||||
# if 'dataset_id' in document.metadata:
|
||||
if 'dataset_id' in document.metadata:
|
||||
query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
|
||||
if "dataset_id" in document.metadata:
|
||||
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
|
||||
# add hit count to document segment
|
||||
query.update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
synchronize_session=False
|
||||
)
|
||||
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# get tracing instance
|
||||
trace_manager: TraceQueueManager = self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
||||
trace_manager: TraceQueueManager = (
|
||||
self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
||||
)
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.DATASET_RETRIEVAL_TRACE,
|
||||
message_id=message_id,
|
||||
documents=documents,
|
||||
timer=timer
|
||||
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
|
||||
)
|
||||
)
|
||||
|
||||
@@ -395,10 +392,10 @@ class DatasetRetrieval:
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset_id,
|
||||
content=query,
|
||||
source='app',
|
||||
source="app",
|
||||
source_app_id=app_id,
|
||||
created_by_role=user_from,
|
||||
created_by=user_id
|
||||
created_by=user_id,
|
||||
)
|
||||
dataset_queries.append(dataset_query)
|
||||
if dataset_queries:
|
||||
@@ -407,9 +404,7 @@ class DatasetRetrieval:
|
||||
|
||||
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
|
||||
with flask_app.app_context():
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
return []
|
||||
@@ -419,38 +414,42 @@ class DatasetRetrieval:
|
||||
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(retrieval_method='keyword_search',
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k
|
||||
)
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k
|
||||
)
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
else:
|
||||
if top_k > 0:
|
||||
# retrieval source
|
||||
documents = RetrievalService.retrieve(retrieval_method=retrieval_model['search_method'],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=retrieval_model.get('score_threshold', .0)
|
||||
if retrieval_model['score_threshold_enabled'] else None,
|
||||
reranking_model=retrieval_model.get('reranking_model', None)
|
||||
if retrieval_model['reranking_enable'] else None,
|
||||
reranking_mode=retrieval_model.get('reranking_mode')
|
||||
if retrieval_model.get('reranking_mode') else 'reranking_model',
|
||||
weights=retrieval_model.get('weights', None),
|
||||
)
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model["search_method"],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else None,
|
||||
reranking_model=retrieval_model.get("reranking_model", None)
|
||||
if retrieval_model["reranking_enable"]
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode")
|
||||
if retrieval_model.get("reranking_mode")
|
||||
else "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
)
|
||||
|
||||
all_documents.extend(documents)
|
||||
|
||||
def to_dataset_retriever_tool(self, tenant_id: str,
|
||||
dataset_ids: list[str],
|
||||
retrieve_config: DatasetRetrieveConfigEntity,
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler) \
|
||||
-> Optional[list[DatasetRetrieverBaseTool]]:
|
||||
def to_dataset_retriever_tool(
|
||||
self,
|
||||
tenant_id: str,
|
||||
dataset_ids: list[str],
|
||||
retrieve_config: DatasetRetrieveConfigEntity,
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
) -> Optional[list[DatasetRetrieverBaseTool]]:
|
||||
"""
|
||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||
:param tenant_id: tenant id
|
||||
@@ -464,18 +463,14 @@ class DatasetRetrieval:
|
||||
available_datasets = []
|
||||
for dataset_id in dataset_ids:
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
continue
|
||||
|
||||
# pass if dataset is not available
|
||||
if (dataset and dataset.available_document_count == 0
|
||||
and dataset.available_document_count == 0):
|
||||
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
|
||||
continue
|
||||
|
||||
available_datasets.append(dataset)
|
||||
@@ -483,22 +478,18 @@ class DatasetRetrieval:
|
||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
# get retrieval model config
|
||||
default_retrieval_model = {
|
||||
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enabled': False
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
for dataset in available_datasets:
|
||||
retrieval_model_config = dataset.retrieval_model \
|
||||
if dataset.retrieval_model else default_retrieval_model
|
||||
retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
|
||||
# get top k
|
||||
top_k = retrieval_model_config['top_k']
|
||||
top_k = retrieval_model_config["top_k"]
|
||||
|
||||
# get score threshold
|
||||
score_threshold = None
|
||||
@@ -512,7 +503,7 @@ class DatasetRetrieval:
|
||||
score_threshold=score_threshold,
|
||||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
retriever_from=invoke_from.to_source()
|
||||
retriever_from=invoke_from.to_source(),
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
@@ -525,8 +516,8 @@ class DatasetRetrieval:
|
||||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
retriever_from=invoke_from.to_source(),
|
||||
reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
|
||||
reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
|
||||
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
|
||||
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
@@ -547,7 +538,7 @@ class DatasetRetrieval:
|
||||
for document in documents:
|
||||
# get the document keywords
|
||||
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
|
||||
document.metadata['keywords'] = document_keywords
|
||||
document.metadata["keywords"] = document_keywords
|
||||
documents_keywords.append(document_keywords)
|
||||
|
||||
# Counter query keywords(TF)
|
||||
@@ -606,21 +597,19 @@ class DatasetRetrieval:
|
||||
|
||||
for document, score in zip(documents, similarities):
|
||||
# format document
|
||||
document.metadata['score'] = score
|
||||
documents = sorted(documents, key=lambda x: x.metadata['score'], reverse=True)
|
||||
document.metadata["score"] = score
|
||||
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
|
||||
return documents[:top_k] if top_k else documents
|
||||
|
||||
def calculate_vector_score(self, all_documents: list[Document],
|
||||
top_k: int, score_threshold: float) -> list[Document]:
|
||||
def calculate_vector_score(
|
||||
self, all_documents: list[Document], top_k: int, score_threshold: float
|
||||
) -> list[Document]:
|
||||
filter_documents = []
|
||||
for document in all_documents:
|
||||
if score_threshold is None or document.metadata['score'] >= score_threshold:
|
||||
if score_threshold is None or document.metadata["score"] >= score_threshold:
|
||||
filter_documents.append(document)
|
||||
|
||||
if not filter_documents:
|
||||
return []
|
||||
filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True)
|
||||
filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True)
|
||||
return filter_documents[:top_k] if top_k else filter_documents
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user