Fix: remove t-SNE position computation from hit testing (#5858)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
Jyong
2024-07-02 17:57:42 +08:00
committed by GitHub
parent d468f8b75c
commit 0944ca9d91
3 changed files with 7 additions and 89 deletions

View File

@@ -4,10 +4,6 @@ import time
import numpy as np
from sklearn.manifold import TSNE
from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document
from core.rag.retrieval.retrival_methods import RetrievalMethod
@@ -45,17 +41,6 @@ class HitTestingService:
if not retrieval_model:
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
# get embedding model
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
provider=dataset.embedding_model_provider,
model=dataset.embedding_model
)
embeddings = CacheEmbedding(embedding_model)
all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
dataset_id=dataset.id,
query=query,
@@ -80,20 +65,10 @@ class HitTestingService:
db.session.add(dataset_query)
db.session.commit()
return cls.compact_retrieve_response(dataset, embeddings, query, all_documents)
return cls.compact_retrieve_response(dataset, query, all_documents)
@classmethod
def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: list[Document]):
text_embeddings = [
embeddings.embed_query(query)
]
text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents]))
tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings)
query_position = tsne_position_data.pop(0)
def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
i = 0
records = []
for document in documents:
@@ -113,7 +88,6 @@ class HitTestingService:
record = {
"segment": segment,
"score": document.metadata.get('score', None),
"tsne_position": tsne_position_data[i]
}
records.append(record)
@@ -123,7 +97,6 @@ class HitTestingService:
return {
"query": {
"content": query,
"tsne_position": query_position,
},
"records": records
}