Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: Garfield Dai <dai.hai@foxmail.com> Co-authored-by: chenhe <guchenhe@gmail.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
@@ -11,7 +10,9 @@ from langchain.schema import Document
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rerank.rerank import RerankRunner
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DocumentSegment, DatasetQuery
|
||||
@@ -47,11 +48,14 @@ class HitTestingService:
|
||||
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
|
||||
# get embedding model
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
all_documents = []
|
||||
@@ -93,14 +97,22 @@ class HitTestingService:
|
||||
thread.join()
|
||||
|
||||
if retrieval_model['search_method'] == 'hybrid_search':
|
||||
hybrid_rerank = ModelFactory.get_reranking_model(
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
|
||||
model_name=retrieval_model['reranking_model']['reranking_model_name']
|
||||
provider=retrieval_model['reranking_model']['reranking_provider_name'],
|
||||
model_type=ModelType.RERANK,
|
||||
model=retrieval_model['reranking_model']['reranking_model_name']
|
||||
)
|
||||
|
||||
rerank_runner = RerankRunner(rerank_model_instance)
|
||||
all_documents = rerank_runner.run(
|
||||
query=query,
|
||||
documents=all_documents,
|
||||
score_threshold=retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
|
||||
top_n=retrieval_model['top_k'],
|
||||
user=f"account-{account.id}"
|
||||
)
|
||||
all_documents = hybrid_rerank.rerank(query, all_documents,
|
||||
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
|
||||
retrieval_model['top_k'])
|
||||
|
||||
end = time.perf_counter()
|
||||
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
|
||||
|
Reference in New Issue
Block a user