chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -13,13 +13,18 @@ from core.rag.rerank.entity.weight import VectorSetting, Weights
|
||||
|
||||
|
||||
class WeightRerankRunner:
|
||||
|
||||
def __init__(self, tenant_id: str, weights: Weights) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.weights = weights
|
||||
|
||||
def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
documents: list[Document],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Run rerank model
|
||||
:param query: search query
|
||||
@@ -34,8 +39,8 @@ class WeightRerankRunner:
|
||||
doc_id = []
|
||||
unique_documents = []
|
||||
for document in documents:
|
||||
if document.metadata['doc_id'] not in doc_id:
|
||||
doc_id.append(document.metadata['doc_id'])
|
||||
if document.metadata["doc_id"] not in doc_id:
|
||||
doc_id.append(document.metadata["doc_id"])
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
|
||||
@@ -47,13 +52,15 @@ class WeightRerankRunner:
|
||||
query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)
|
||||
for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):
|
||||
# format document
|
||||
score = self.weights.vector_setting.vector_weight * query_vector_score + \
|
||||
self.weights.keyword_setting.keyword_weight * query_score
|
||||
score = (
|
||||
self.weights.vector_setting.vector_weight * query_vector_score
|
||||
+ self.weights.keyword_setting.keyword_weight * query_score
|
||||
)
|
||||
if score_threshold and score < score_threshold:
|
||||
continue
|
||||
document.metadata['score'] = score
|
||||
document.metadata["score"] = score
|
||||
rerank_documents.append(document)
|
||||
rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata['score'], reverse=True)
|
||||
rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata["score"], reverse=True)
|
||||
return rerank_documents[:top_n] if top_n else rerank_documents
|
||||
|
||||
def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
|
||||
@@ -70,7 +77,7 @@ class WeightRerankRunner:
|
||||
for document in documents:
|
||||
# get the document keywords
|
||||
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
|
||||
document.metadata['keywords'] = document_keywords
|
||||
document.metadata["keywords"] = document_keywords
|
||||
documents_keywords.append(document_keywords)
|
||||
|
||||
# Counter query keywords(TF)
|
||||
@@ -132,8 +139,9 @@ class WeightRerankRunner:
|
||||
|
||||
return similarities
|
||||
|
||||
def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document],
|
||||
vector_setting: VectorSetting) -> list[float]:
|
||||
def _calculate_cosine(
|
||||
self, tenant_id: str, query: str, documents: list[Document], vector_setting: VectorSetting
|
||||
) -> list[float]:
|
||||
"""
|
||||
Calculate Cosine scores
|
||||
:param query: search query
|
||||
@@ -149,15 +157,14 @@ class WeightRerankRunner:
|
||||
tenant_id=tenant_id,
|
||||
provider=vector_setting.embedding_provider_name,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=vector_setting.embedding_model_name
|
||||
|
||||
model=vector_setting.embedding_model_name,
|
||||
)
|
||||
cache_embedding = CacheEmbedding(embedding_model)
|
||||
query_vector = cache_embedding.embed_query(query)
|
||||
for document in documents:
|
||||
# calculate cosine similarity
|
||||
if 'score' in document.metadata:
|
||||
query_vector_scores.append(document.metadata['score'])
|
||||
if "score" in document.metadata:
|
||||
query_vector_scores.append(document.metadata["score"])
|
||||
else:
|
||||
# transform to NumPy
|
||||
vec1 = np.array(query_vector)
|
||||
|
Reference in New Issue
Block a user