feat: add ops trace (#5483)
Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
@@ -12,6 +12,8 @@ from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.ops.ops_trace_manager import TraceTask, TraceTaskName
|
||||
from core.ops.utils import measure_time
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.rerank import RerankRunner
|
||||
@@ -38,14 +40,20 @@ default_retrieval_model = {
|
||||
|
||||
|
||||
class DatasetRetrieval:
|
||||
def retrieve(self, app_id: str, user_id: str, tenant_id: str,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
config: DatasetEntity,
|
||||
query: str,
|
||||
invoke_from: InvokeFrom,
|
||||
show_retrieve_source: bool,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
|
||||
def __init__(self, application_generate_entity=None):
|
||||
self.application_generate_entity = application_generate_entity
|
||||
|
||||
def retrieve(
|
||||
self, app_id: str, user_id: str, tenant_id: str,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
config: DatasetEntity,
|
||||
query: str,
|
||||
invoke_from: InvokeFrom,
|
||||
show_retrieve_source: bool,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
message_id: str,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Retrieve dataset.
|
||||
:param app_id: app_id
|
||||
@@ -57,6 +65,7 @@ class DatasetRetrieval:
|
||||
:param invoke_from: invoke from
|
||||
:param show_retrieve_source: show retrieve source
|
||||
:param hit_callback: hit callback
|
||||
:param message_id: message id
|
||||
:param memory: memory
|
||||
:return:
|
||||
"""
|
||||
@@ -113,15 +122,20 @@ class DatasetRetrieval:
|
||||
all_documents = []
|
||||
user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'
|
||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
all_documents = self.single_retrieve(app_id, tenant_id, user_id, user_from, available_datasets, query,
|
||||
model_instance,
|
||||
model_config, planning_strategy)
|
||||
all_documents = self.single_retrieve(
|
||||
app_id, tenant_id, user_id, user_from, available_datasets, query,
|
||||
model_instance,
|
||||
model_config, planning_strategy, message_id
|
||||
)
|
||||
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
all_documents = self.multiple_retrieve(app_id, tenant_id, user_id, user_from,
|
||||
available_datasets, query, retrieve_config.top_k,
|
||||
retrieve_config.score_threshold,
|
||||
retrieve_config.reranking_model.get('reranking_provider_name'),
|
||||
retrieve_config.reranking_model.get('reranking_model_name'))
|
||||
all_documents = self.multiple_retrieve(
|
||||
app_id, tenant_id, user_id, user_from,
|
||||
available_datasets, query, retrieve_config.top_k,
|
||||
retrieve_config.score_threshold,
|
||||
retrieve_config.reranking_model.get('reranking_provider_name'),
|
||||
retrieve_config.reranking_model.get('reranking_model_name'),
|
||||
message_id,
|
||||
)
|
||||
|
||||
document_score_list = {}
|
||||
for item in all_documents:
|
||||
@@ -189,16 +203,18 @@ class DatasetRetrieval:
|
||||
return str("\n".join(document_context_list))
|
||||
return ''
|
||||
|
||||
def single_retrieve(self, app_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
query: str,
|
||||
model_instance: ModelInstance,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
planning_strategy: PlanningStrategy,
|
||||
):
|
||||
def single_retrieve(
|
||||
self, app_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
query: str,
|
||||
model_instance: ModelInstance,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
planning_strategy: PlanningStrategy,
|
||||
message_id: Optional[str] = None,
|
||||
):
|
||||
tools = []
|
||||
for dataset in available_datasets:
|
||||
description = dataset.description
|
||||
@@ -251,27 +267,35 @@ class DatasetRetrieval:
|
||||
if score_threshold_enabled:
|
||||
score_threshold = retrieval_model_config.get("score_threshold")
|
||||
|
||||
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k, score_threshold=score_threshold,
|
||||
reranking_model=reranking_model)
|
||||
with measure_time() as timer:
|
||||
results = RetrievalService.retrieve(
|
||||
retrival_method=retrival_method, dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k, score_threshold=score_threshold,
|
||||
reranking_model=reranking_model
|
||||
)
|
||||
self._on_query(query, [dataset_id], app_id, user_from, user_id)
|
||||
|
||||
if results:
|
||||
self._on_retrival_end(results)
|
||||
self._on_retrival_end(results, message_id, timer)
|
||||
|
||||
return results
|
||||
return []
|
||||
|
||||
def multiple_retrieve(self,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_provider_name: str,
|
||||
reranking_model_name: str):
|
||||
def multiple_retrieve(
|
||||
self,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_provider_name: str,
|
||||
reranking_model_name: str,
|
||||
message_id: Optional[str] = None,
|
||||
):
|
||||
threads = []
|
||||
all_documents = []
|
||||
dataset_ids = [dataset.id for dataset in available_datasets]
|
||||
@@ -297,15 +321,23 @@ class DatasetRetrieval:
|
||||
)
|
||||
|
||||
rerank_runner = RerankRunner(rerank_model_instance)
|
||||
all_documents = rerank_runner.run(query, all_documents,
|
||||
score_threshold,
|
||||
top_k)
|
||||
|
||||
with measure_time() as timer:
|
||||
all_documents = rerank_runner.run(
|
||||
query, all_documents,
|
||||
score_threshold,
|
||||
top_k
|
||||
)
|
||||
self._on_query(query, dataset_ids, app_id, user_from, user_id)
|
||||
|
||||
if all_documents:
|
||||
self._on_retrival_end(all_documents)
|
||||
self._on_retrival_end(all_documents, message_id, timer)
|
||||
|
||||
return all_documents
|
||||
|
||||
def _on_retrival_end(self, documents: list[Document]) -> None:
|
||||
def _on_retrival_end(
|
||||
self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None
|
||||
) -> None:
|
||||
"""Handle retrival end."""
|
||||
for document in documents:
|
||||
query = db.session.query(DocumentSegment).filter(
|
||||
@@ -324,6 +356,18 @@ class DatasetRetrieval:
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.DATASET_RETRIEVAL_TRACE,
|
||||
message_id=message_id,
|
||||
documents=documents,
|
||||
timer=timer
|
||||
)
|
||||
)
|
||||
|
||||
def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None:
|
||||
"""
|
||||
Handle query.
|
||||
|
Reference in New Issue
Block a user