feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -1,7 +1,7 @@
import math
import threading
from collections import Counter
from typing import Optional, cast
from typing import Any, Optional, cast
from flask import Flask, current_app
@@ -34,7 +34,7 @@ from models.dataset import Dataset, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@@ -140,12 +140,12 @@ class DatasetRetrieval:
user_from,
available_datasets,
query,
retrieve_config.top_k,
retrieve_config.score_threshold,
retrieve_config.rerank_mode,
retrieve_config.top_k or 0,
retrieve_config.score_threshold or 0,
retrieve_config.rerank_mode or "reranking_model",
retrieve_config.reranking_model,
retrieve_config.weights,
retrieve_config.reranking_enabled,
retrieve_config.reranking_enabled or True,
message_id,
)
@@ -300,10 +300,11 @@ class DatasetRetrieval:
metadata=external_document.get("metadata"),
provider="external",
)
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset_id
document.metadata["dataset_name"] = dataset.name
if document.metadata is not None:
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset_id
document.metadata["dataset_name"] = dataset.name
results.append(document)
else:
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
@@ -325,7 +326,7 @@ class DatasetRetrieval:
score_threshold = 0.0
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
score_threshold = retrieval_model_config.get("score_threshold", 0.0)
with measure_time() as timer:
results = RetrievalService.retrieve(
@@ -358,14 +359,14 @@ class DatasetRetrieval:
score_threshold: float,
reranking_mode: str,
reranking_model: Optional[dict] = None,
weights: Optional[dict] = None,
weights: Optional[dict[str, Any]] = None,
reranking_enable: bool = True,
message_id: Optional[str] = None,
):
if not available_datasets:
return []
threads = []
all_documents = []
all_documents: list[Document] = []
dataset_ids = [dataset.id for dataset in available_datasets]
index_type_check = all(
item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets
@@ -392,15 +393,18 @@ class DatasetRetrieval:
"The configured knowledge base list have different embedding model, please set reranking model."
)
if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE:
weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
if weights is not None:
weights["vector_setting"]["embedding_provider_name"] = available_datasets[
0
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
for dataset in available_datasets:
index_type = dataset.indexing_technique
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
@@ -439,21 +443,22 @@ class DatasetRetrieval:
"""Handle retrieval end."""
dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
if document.metadata is not None:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
# add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
db.session.commit()
db.session.commit()
# get tracing instance
trace_manager: TraceQueueManager = (
trace_manager: Optional[TraceQueueManager] = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
@@ -504,10 +509,11 @@ class DatasetRetrieval:
metadata=external_document.get("metadata"),
provider="external",
)
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset_id
document.metadata["dataset_name"] = dataset.name
if document.metadata is not None:
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset_id
document.metadata["dataset_name"] = dataset.name
all_documents.append(document)
else:
# get retrieval model , if the model is not setting , using default
@@ -607,19 +613,20 @@ class DatasetRetrieval:
tools.append(tool)
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,
top_k=retrieve_config.top_k or 2,
score_threshold=retrieve_config.score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source(),
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
)
if retrieve_config.reranking_model is not None:
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,
top_k=retrieve_config.top_k or 2,
score_threshold=retrieve_config.score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source(),
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
)
tools.append(tool)
tools.append(tool)
return tools
@@ -635,10 +642,11 @@ class DatasetRetrieval:
query_keywords = keyword_table_handler.extract_keywords(query, None)
documents_keywords = []
for document in documents:
# get the document keywords
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
document.metadata["keywords"] = document_keywords
documents_keywords.append(document_keywords)
if document.metadata is not None:
# get the document keywords
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
document.metadata["keywords"] = document_keywords
documents_keywords.append(document_keywords)
# Counter query keywords(TF)
query_keyword_counts = Counter(query_keywords)
@@ -696,8 +704,9 @@ class DatasetRetrieval:
for document, score in zip(documents, similarities):
# format document
document.metadata["score"] = score
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
if document.metadata is not None:
document.metadata["score"] = score
documents = sorted(documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return documents[:top_k] if top_k else documents
def calculate_vector_score(
@@ -705,10 +714,12 @@ class DatasetRetrieval:
) -> list[Document]:
filter_documents = []
for document in all_documents:
if score_threshold is None or document.metadata["score"] >= score_threshold:
if score_threshold is None or (document.metadata and document.metadata.get("score", 0) >= score_threshold):
filter_documents.append(document)
if not filter_documents:
return []
filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True)
filter_documents = sorted(
filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
)
return filter_documents[:top_k] if top_k else filter_documents