Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
|
||||
from typing import Optional
|
||||
from flask import current_app, Flask
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.rerank.rerank import RerankRunner
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
|
||||
@@ -50,12 +52,24 @@ class RetrievalService:
|
||||
|
||||
if documents:
|
||||
if reranking_model and search_method == 'semantic_search':
|
||||
rerank = ModelFactory.get_reranking_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=reranking_model['reranking_provider_name'],
|
||||
model_name=reranking_model['reranking_model_name']
|
||||
)
|
||||
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=reranking_model['reranking_provider_name'],
|
||||
model_type=ModelType.RERANK,
|
||||
model=reranking_model['reranking_model_name']
|
||||
)
|
||||
except InvokeAuthorizationError:
|
||||
return
|
||||
|
||||
rerank_runner = RerankRunner(rerank_model_instance)
|
||||
all_documents.extend(rerank_runner.run(
|
||||
query=query,
|
||||
documents=documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=len(documents)
|
||||
))
|
||||
else:
|
||||
all_documents.extend(documents)
|
||||
|
||||
@@ -81,15 +95,23 @@ class RetrievalService:
|
||||
)
|
||||
if documents:
|
||||
if reranking_model and search_method == 'full_text_search':
|
||||
rerank = ModelFactory.get_reranking_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=reranking_model['reranking_provider_name'],
|
||||
model_name=reranking_model['reranking_model_name']
|
||||
)
|
||||
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=reranking_model['reranking_provider_name'],
|
||||
model_type=ModelType.RERANK,
|
||||
model=reranking_model['reranking_model_name']
|
||||
)
|
||||
except InvokeAuthorizationError:
|
||||
return
|
||||
|
||||
rerank_runner = RerankRunner(rerank_model_instance)
|
||||
all_documents.extend(rerank_runner.run(
|
||||
query=query,
|
||||
documents=documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=len(documents)
|
||||
))
|
||||
else:
|
||||
all_documents.extend(documents)
|
||||
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user