feat: add ops trace (#5483)

Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
Joe
2024-06-26 17:33:29 +08:00
committed by GitHub
parent 31a061ebaa
commit 4e2de638af
58 changed files with 3553 additions and 622 deletions

View File

@@ -12,6 +12,8 @@ from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.ops.ops_trace_manager import TraceTask, TraceTaskName
from core.ops.utils import measure_time
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document
from core.rag.rerank.rerank import RerankRunner
@@ -38,14 +40,20 @@ default_retrieval_model = {
class DatasetRetrieval:
def retrieve(self, app_id: str, user_id: str, tenant_id: str,
model_config: ModelConfigWithCredentialsEntity,
config: DatasetEntity,
query: str,
invoke_from: InvokeFrom,
show_retrieve_source: bool,
hit_callback: DatasetIndexToolCallbackHandler,
memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
def __init__(self, application_generate_entity=None):
    """
    Store the optional application generate entity on the instance.

    :param application_generate_entity: object whose ``trace_manager``
        attribute is read in ``_on_retrival_end`` to queue a
        ``DATASET_RETRIEVAL_TRACE`` task; when left as ``None`` (the default),
        no trace task is added there.
    """
    self.application_generate_entity = application_generate_entity
def retrieve(
self, app_id: str, user_id: str, tenant_id: str,
model_config: ModelConfigWithCredentialsEntity,
config: DatasetEntity,
query: str,
invoke_from: InvokeFrom,
show_retrieve_source: bool,
hit_callback: DatasetIndexToolCallbackHandler,
message_id: str,
memory: Optional[TokenBufferMemory] = None,
) -> Optional[str]:
"""
Retrieve dataset.
:param app_id: app_id
@@ -57,6 +65,7 @@ class DatasetRetrieval:
:param invoke_from: invoke from
:param show_retrieve_source: show retrieve source
:param hit_callback: hit callback
:param message_id: message id
:param memory: memory
:return:
"""
@@ -113,15 +122,20 @@ class DatasetRetrieval:
all_documents = []
user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
all_documents = self.single_retrieve(app_id, tenant_id, user_id, user_from, available_datasets, query,
model_instance,
model_config, planning_strategy)
all_documents = self.single_retrieve(
app_id, tenant_id, user_id, user_from, available_datasets, query,
model_instance,
model_config, planning_strategy, message_id
)
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
all_documents = self.multiple_retrieve(app_id, tenant_id, user_id, user_from,
available_datasets, query, retrieve_config.top_k,
retrieve_config.score_threshold,
retrieve_config.reranking_model.get('reranking_provider_name'),
retrieve_config.reranking_model.get('reranking_model_name'))
all_documents = self.multiple_retrieve(
app_id, tenant_id, user_id, user_from,
available_datasets, query, retrieve_config.top_k,
retrieve_config.score_threshold,
retrieve_config.reranking_model.get('reranking_provider_name'),
retrieve_config.reranking_model.get('reranking_model_name'),
message_id,
)
document_score_list = {}
for item in all_documents:
@@ -189,16 +203,18 @@ class DatasetRetrieval:
return str("\n".join(document_context_list))
return ''
def single_retrieve(self, app_id: str,
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
):
def single_retrieve(
self, app_id: str,
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
message_id: Optional[str] = None,
):
tools = []
for dataset in available_datasets:
description = dataset.description
@@ -251,27 +267,35 @@ class DatasetRetrieval:
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
query=query,
top_k=top_k, score_threshold=score_threshold,
reranking_model=reranking_model)
with measure_time() as timer:
results = RetrievalService.retrieve(
retrival_method=retrival_method, dataset_id=dataset.id,
query=query,
top_k=top_k, score_threshold=score_threshold,
reranking_model=reranking_model
)
self._on_query(query, [dataset_id], app_id, user_from, user_id)
if results:
self._on_retrival_end(results)
self._on_retrival_end(results, message_id, timer)
return results
return []
def multiple_retrieve(self,
app_id: str,
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
top_k: int,
score_threshold: float,
reranking_provider_name: str,
reranking_model_name: str):
def multiple_retrieve(
self,
app_id: str,
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
top_k: int,
score_threshold: float,
reranking_provider_name: str,
reranking_model_name: str,
message_id: Optional[str] = None,
):
threads = []
all_documents = []
dataset_ids = [dataset.id for dataset in available_datasets]
@@ -297,15 +321,23 @@ class DatasetRetrieval:
)
rerank_runner = RerankRunner(rerank_model_instance)
all_documents = rerank_runner.run(query, all_documents,
score_threshold,
top_k)
with measure_time() as timer:
all_documents = rerank_runner.run(
query, all_documents,
score_threshold,
top_k
)
self._on_query(query, dataset_ids, app_id, user_from, user_id)
if all_documents:
self._on_retrival_end(all_documents)
self._on_retrival_end(all_documents, message_id, timer)
return all_documents
def _on_retrival_end(self, documents: list[Document]) -> None:
def _on_retrival_end(
self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None
) -> None:
"""Handle retrival end."""
for document in documents:
query = db.session.query(DocumentSegment).filter(
@@ -324,6 +356,18 @@ class DatasetRetrieval:
db.session.commit()
# get tracing instance
trace_manager = self.application_generate_entity.trace_manager if self.application_generate_entity else None
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.DATASET_RETRIEVAL_TRACE,
message_id=message_id,
documents=documents,
timer=timer
)
)
def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None:
"""
Handle query.