feat: mypy for all type check (#10921)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import math
|
||||
import threading
|
||||
from collections import Counter
|
||||
from typing import Optional, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
@@ -34,7 +34,7 @@ from models.dataset import Dataset, DatasetQuery, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
@@ -140,12 +140,12 @@ class DatasetRetrieval:
|
||||
user_from,
|
||||
available_datasets,
|
||||
query,
|
||||
retrieve_config.top_k,
|
||||
retrieve_config.score_threshold,
|
||||
retrieve_config.rerank_mode,
|
||||
retrieve_config.top_k or 0,
|
||||
retrieve_config.score_threshold or 0,
|
||||
retrieve_config.rerank_mode or "reranking_model",
|
||||
retrieve_config.reranking_model,
|
||||
retrieve_config.weights,
|
||||
retrieve_config.reranking_enabled,
|
||||
retrieve_config.reranking_enabled or True,
|
||||
message_id,
|
||||
)
|
||||
|
||||
@@ -300,10 +300,11 @@ class DatasetRetrieval:
|
||||
metadata=external_document.get("metadata"),
|
||||
provider="external",
|
||||
)
|
||||
document.metadata["score"] = external_document.get("score")
|
||||
document.metadata["title"] = external_document.get("title")
|
||||
document.metadata["dataset_id"] = dataset_id
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
if document.metadata is not None:
|
||||
document.metadata["score"] = external_document.get("score")
|
||||
document.metadata["title"] = external_document.get("title")
|
||||
document.metadata["dataset_id"] = dataset_id
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
else:
|
||||
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
|
||||
@@ -325,7 +326,7 @@ class DatasetRetrieval:
|
||||
score_threshold = 0.0
|
||||
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
|
||||
if score_threshold_enabled:
|
||||
score_threshold = retrieval_model_config.get("score_threshold")
|
||||
score_threshold = retrieval_model_config.get("score_threshold", 0.0)
|
||||
|
||||
with measure_time() as timer:
|
||||
results = RetrievalService.retrieve(
|
||||
@@ -358,14 +359,14 @@ class DatasetRetrieval:
|
||||
score_threshold: float,
|
||||
reranking_mode: str,
|
||||
reranking_model: Optional[dict] = None,
|
||||
weights: Optional[dict] = None,
|
||||
weights: Optional[dict[str, Any]] = None,
|
||||
reranking_enable: bool = True,
|
||||
message_id: Optional[str] = None,
|
||||
):
|
||||
if not available_datasets:
|
||||
return []
|
||||
threads = []
|
||||
all_documents = []
|
||||
all_documents: list[Document] = []
|
||||
dataset_ids = [dataset.id for dataset in available_datasets]
|
||||
index_type_check = all(
|
||||
item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets
|
||||
@@ -392,15 +393,18 @@ class DatasetRetrieval:
|
||||
"The configured knowledge base list have different embedding model, please set reranking model."
|
||||
)
|
||||
if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE:
|
||||
weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider
|
||||
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
|
||||
if weights is not None:
|
||||
weights["vector_setting"]["embedding_provider_name"] = available_datasets[
|
||||
0
|
||||
].embedding_model_provider
|
||||
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
|
||||
|
||||
for dataset in available_datasets:
|
||||
index_type = dataset.indexing_technique
|
||||
retrieval_thread = threading.Thread(
|
||||
target=self._retriever,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"dataset_id": dataset.id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
@@ -439,21 +443,22 @@ class DatasetRetrieval:
|
||||
"""Handle retrieval end."""
|
||||
dify_documents = [document for document in documents if document.provider == "dify"]
|
||||
for document in dify_documents:
|
||||
query = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
)
|
||||
if document.metadata is not None:
|
||||
query = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
)
|
||||
|
||||
# if 'dataset_id' in document.metadata:
|
||||
if "dataset_id" in document.metadata:
|
||||
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
# if 'dataset_id' in document.metadata:
|
||||
if "dataset_id" in document.metadata:
|
||||
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
|
||||
# add hit count to document segment
|
||||
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
|
||||
# add hit count to document segment
|
||||
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
|
||||
|
||||
db.session.commit()
|
||||
db.session.commit()
|
||||
|
||||
# get tracing instance
|
||||
trace_manager: TraceQueueManager = (
|
||||
trace_manager: Optional[TraceQueueManager] = (
|
||||
self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
||||
)
|
||||
if trace_manager:
|
||||
@@ -504,10 +509,11 @@ class DatasetRetrieval:
|
||||
metadata=external_document.get("metadata"),
|
||||
provider="external",
|
||||
)
|
||||
document.metadata["score"] = external_document.get("score")
|
||||
document.metadata["title"] = external_document.get("title")
|
||||
document.metadata["dataset_id"] = dataset_id
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
if document.metadata is not None:
|
||||
document.metadata["score"] = external_document.get("score")
|
||||
document.metadata["title"] = external_document.get("title")
|
||||
document.metadata["dataset_id"] = dataset_id
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
all_documents.append(document)
|
||||
else:
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
@@ -607,19 +613,20 @@ class DatasetRetrieval:
|
||||
|
||||
tools.append(tool)
|
||||
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
tool = DatasetMultiRetrieverTool.from_dataset(
|
||||
dataset_ids=[dataset.id for dataset in available_datasets],
|
||||
tenant_id=tenant_id,
|
||||
top_k=retrieve_config.top_k or 2,
|
||||
score_threshold=retrieve_config.score_threshold,
|
||||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
retriever_from=invoke_from.to_source(),
|
||||
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
|
||||
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
|
||||
)
|
||||
if retrieve_config.reranking_model is not None:
|
||||
tool = DatasetMultiRetrieverTool.from_dataset(
|
||||
dataset_ids=[dataset.id for dataset in available_datasets],
|
||||
tenant_id=tenant_id,
|
||||
top_k=retrieve_config.top_k or 2,
|
||||
score_threshold=retrieve_config.score_threshold,
|
||||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
retriever_from=invoke_from.to_source(),
|
||||
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
|
||||
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
tools.append(tool)
|
||||
|
||||
return tools
|
||||
|
||||
@@ -635,10 +642,11 @@ class DatasetRetrieval:
|
||||
query_keywords = keyword_table_handler.extract_keywords(query, None)
|
||||
documents_keywords = []
|
||||
for document in documents:
|
||||
# get the document keywords
|
||||
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
|
||||
document.metadata["keywords"] = document_keywords
|
||||
documents_keywords.append(document_keywords)
|
||||
if document.metadata is not None:
|
||||
# get the document keywords
|
||||
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
|
||||
document.metadata["keywords"] = document_keywords
|
||||
documents_keywords.append(document_keywords)
|
||||
|
||||
# Counter query keywords(TF)
|
||||
query_keyword_counts = Counter(query_keywords)
|
||||
@@ -696,8 +704,9 @@ class DatasetRetrieval:
|
||||
|
||||
for document, score in zip(documents, similarities):
|
||||
# format document
|
||||
document.metadata["score"] = score
|
||||
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
|
||||
if document.metadata is not None:
|
||||
document.metadata["score"] = score
|
||||
documents = sorted(documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
|
||||
return documents[:top_k] if top_k else documents
|
||||
|
||||
def calculate_vector_score(
|
||||
@@ -705,10 +714,12 @@ class DatasetRetrieval:
|
||||
) -> list[Document]:
|
||||
filter_documents = []
|
||||
for document in all_documents:
|
||||
if score_threshold is None or document.metadata["score"] >= score_threshold:
|
||||
if score_threshold is None or (document.metadata and document.metadata.get("score", 0) >= score_threshold):
|
||||
filter_documents.append(document)
|
||||
|
||||
if not filter_documents:
|
||||
return []
|
||||
filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True)
|
||||
filter_documents = sorted(
|
||||
filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
|
||||
)
|
||||
return filter_documents[:top_k] if top_k else filter_documents
|
||||
|
Reference in New Issue
Block a user