Feat/fix ops trace (#5672)
Co-authored-by: takatost <takatost@gmail.com>
@@ -94,5 +94,15 @@ class ToolTraceInfo(BaseTraceInfo):
 class GenerateNameTraceInfo(BaseTraceInfo):
-    conversation_id: str
+    conversation_id: Optional[str] = None
     tenant_id: str
+
+trace_info_info_map = {
+    'WorkflowTraceInfo': WorkflowTraceInfo,
+    'MessageTraceInfo': MessageTraceInfo,
+    'ModerationTraceInfo': ModerationTraceInfo,
+    'SuggestedQuestionTraceInfo': SuggestedQuestionTraceInfo,
+    'DatasetRetrievalTraceInfo': DatasetRetrievalTraceInfo,
+    'ToolTraceInfo': ToolTraceInfo,
+    'GenerateNameTraceInfo': GenerateNameTraceInfo,
+}
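The new trace_info_info_map registry pairs with the Celery hand-off introduced further down: once a trace-info object has been flattened with model_dump(), a consumer needs a way to get back to the right pydantic class. A minimal sketch of that lookup, assuming trace_info_info_map is importable from this module; the helper name rebuild_trace_info is purely illustrative and not part of the diff:

    from typing import Any

    def rebuild_trace_info(trace_info_type: str, payload: dict[str, Any]) -> BaseTraceInfo:
        # Look up the concrete class recorded by the producer as
        # type(trace_info).__name__, then let pydantic re-validate the dumped fields.
        cls = trace_info_info_map.get(trace_info_type)
        if cls is None:
            raise ValueError(f"unknown trace info type: {trace_info_type}")
        return cls(**payload)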
@@ -147,6 +147,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             # add span
             if trace_info.message_id:
                 span_data = LangfuseSpan(
+                    id=node_execution_id,
                     name=f"{node_name}_{node_execution_id}",
                     input=inputs,
                     output=outputs,
@@ -160,6 +161,7 @@ class LangFuseDataTrace(BaseTraceInstance):
                 )
             else:
                 span_data = LangfuseSpan(
+                    id=node_execution_id,
                     name=f"{node_name}_{node_execution_id}",
                     input=inputs,
                     output=outputs,
@@ -173,6 +175,30 @@ class LangFuseDataTrace(BaseTraceInstance):

             self.add_span(langfuse_span_data=span_data)

+            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
+            if process_data and process_data.get("model_mode") == "chat":
+                total_token = metadata.get("total_tokens", 0)
+                # add generation
+                generation_usage = GenerationUsage(
+                    totalTokens=total_token,
+                )
+
+                node_generation_data = LangfuseGeneration(
+                    name=f"generation_{node_execution_id}",
+                    trace_id=trace_id,
+                    parent_observation_id=node_execution_id,
+                    start_time=created_at,
+                    end_time=finished_at,
+                    input=inputs,
+                    output=outputs,
+                    metadata=metadata,
+                    level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR,
+                    status_message=trace_info.error if trace_info.error else "",
+                    usage=generation_usage,
+                )
+
+                self.add_generation(langfuse_generation_data=node_generation_data)
+
     def message_trace(
         self, trace_info: MessageTraceInfo, **kwargs
     ):
@@ -186,7 +212,7 @@ class LangFuseDataTrace(BaseTraceInstance):
         if message_data.from_end_user_id:
             end_user_data: EndUser = db.session.query(EndUser).filter(
                 EndUser.id == message_data.from_end_user_id
-            ).first().session_id
+            ).first()
+            user_id = end_user_data.session_id

         trace_data = LangfuseTrace(
@@ -220,6 +246,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             output=trace_info.answer_tokens,
             total=trace_info.total_tokens,
             unit=UnitEnum.TOKENS,
+            totalCost=message_data.total_price,
         )

         langfuse_generation_data = LangfuseGeneration(
@@ -303,7 +330,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             start_time=trace_info.start_time,
             end_time=trace_info.end_time,
             metadata=trace_info.metadata,
-            level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
+            level=LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR,
             status_message=trace_info.error,
         )
@@ -1,16 +1,17 @@
 import json
 import logging
 import os
 import queue
 import threading
 import time
 from datetime import timedelta
 from enum import Enum
 from typing import Any, Optional, Union
 from uuid import UUID

-from flask import Flask, current_app
+from flask import current_app

 from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
 from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import (
     LangfuseConfig,
     LangSmithConfig,
@@ -31,6 +32,7 @@ from core.ops.utils import get_message_data
 from extensions.ext_database import db
 from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig
 from models.workflow import WorkflowAppLog, WorkflowRun
+from tasks.ops_trace_task import process_trace_tasks

 provider_config_map = {
     TracingProviderEnum.LANGFUSE.value: {
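The newly imported process_trace_tasks is the Celery task that receives the batched payloads built by send_to_celery() further down. Its body is not part of this diff; what follows is only a hypothetical skeleton of a worker consuming task_data — the decorator usage and every name here are assumptions, not the real tasks/ops_trace_task.py:

    from celery import shared_task

    @shared_task
    def process_trace_tasks_sketch(task_data: dict):
        # The producer has already flattened everything to plain values.
        trace_info_type = task_data.get("trace_info_type")
        trace_info = task_data.get("trace_info", {})
        app_id = task_data.get("app_id")
        # A worker would now look up the configured tracing provider for app_id,
        # rebuild the typed trace-info object (see trace_info_info_map above),
        # and call trace_instance.trace(...) on it.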
@@ -105,7 +107,7 @@ class OpsTraceManager:
         return config_class(**new_config).model_dump()

     @classmethod
-    def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config:dict):
+    def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
         """
         Decrypt tracing config
         :param tracing_provider: tracing provider
@@ -295,11 +297,9 @@ class TraceTask:
         self.kwargs = kwargs
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")

-    def execute(self, trace_instance: BaseTraceInstance):
+    def execute(self):
         method_name, trace_info = self.preprocess()
-        if trace_instance:
-            method = trace_instance.trace
-            method(trace_info)
         return trace_info

     def preprocess(self):
         if self.trace_type == TraceTaskName.CONVERSATION_TRACE:
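With the trace_instance parameter gone, execute() only builds the trace-info model; exporting to the provider is deferred to the Celery worker. A short sketch of the resulting call pattern, where trace_task stands for any already-constructed TraceTask:

    trace_info = trace_task.execute()                        # runs preprocess() only, no provider call
    payload = trace_info.model_dump() if trace_info else {}  # JSON-safe dict for the Celery task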
@@ -372,7 +372,7 @@ class TraceTask:
         }

         workflow_trace_info = WorkflowTraceInfo(
-            workflow_data=workflow_run,
+            workflow_data=workflow_run.to_dict(),
             conversation_id=conversation_id,
             workflow_id=workflow_id,
             tenant_id=tenant_id,
@@ -427,7 +427,8 @@ class TraceTask:
         message_tokens = message_data.message_tokens

         message_trace_info = MessageTraceInfo(
-            message_data=message_data,
+            message_id=message_id,
+            message_data=message_data.to_dict(),
             conversation_model=conversation_mode,
             message_tokens=message_tokens,
             answer_tokens=message_data.answer_tokens,
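The recurring pattern in these hunks — workflow_run.to_dict(), message_data.to_dict(), and later doc.model_dump() — exists because the trace info is now shipped to a Celery task instead of being handed to an in-process trace instance, so every payload must be reducible to plain data. A tiny, hypothetical guard (not part of the diff) that could assert this before dispatch:

    import json

    def assert_json_safe(payload: dict) -> None:
        # Raises TypeError if the payload still contains values the default JSON
        # encoder cannot handle, e.g. an un-flattened SQLAlchemy model.
        json.dumps(payload)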
@@ -469,7 +470,7 @@ class TraceTask:
         moderation_trace_info = ModerationTraceInfo(
             message_id=workflow_app_log_id if workflow_app_log_id else message_id,
             inputs=inputs,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
             flagged=moderation_result.flagged,
             action=moderation_result.action,
             preset_response=moderation_result.preset_response,
@@ -508,7 +509,7 @@ class TraceTask:

         suggested_question_trace_info = SuggestedQuestionTraceInfo(
             message_id=workflow_app_log_id if workflow_app_log_id else message_id,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
             inputs=message_data.message,
             outputs=message_data.answer,
             start_time=timer.get("start"),
@@ -550,11 +551,11 @@ class TraceTask:
         dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
             message_id=message_id,
             inputs=message_data.query if message_data.query else message_data.inputs,
-            documents=documents,
+            documents=[doc.model_dump() for doc in documents],
             start_time=timer.get("start"),
             end_time=timer.get("end"),
             metadata=metadata,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
         )

         return dataset_retrieval_trace_info
@@ -613,7 +614,7 @@ class TraceTask:

         tool_trace_info = ToolTraceInfo(
             message_id=message_id,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
             tool_name=tool_name,
             start_time=timer.get("start") if timer else created_time,
             end_time=timer.get("end") if timer else end_time,
@@ -657,31 +658,71 @@ class TraceTask:
         return generate_name_trace_info


+trace_manager_timer = None
+trace_manager_queue = queue.Queue()
+trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 1))
+trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
+
+
 class TraceQueueManager:
     def __init__(self, app_id=None, conversation_id=None, message_id=None):
-        tracing_instance = OpsTraceManager.get_ops_trace_instance(app_id, conversation_id, message_id)
-        self.queue = queue.Queue()
-        self.is_running = True
-        self.thread = threading.Thread(
-            target=self.process_queue, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'trace_instance': tracing_instance
-            }
-        )
-        self.thread.start()
+        global trace_manager_timer

-    def stop(self):
-        self.is_running = False
-
-    def process_queue(self, flask_app: Flask, trace_instance: BaseTraceInstance):
-        with flask_app.app_context():
-            while self.is_running:
-                try:
-                    task = self.queue.get(timeout=60)
-                    task.execute(trace_instance)
-                    self.queue.task_done()
-                except queue.Empty:
-                    self.stop()
+        self.app_id = app_id
+        self.conversation_id = conversation_id
+        self.message_id = message_id
+        self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id, conversation_id, message_id)
+        self.flask_app = current_app._get_current_object()
+        if trace_manager_timer is None:
+            self.start_timer()

     def add_trace_task(self, trace_task: TraceTask):
-        self.queue.put(trace_task)
+        global trace_manager_timer
+        global trace_manager_queue
+        try:
+            if self.trace_instance:
+                trace_manager_queue.put(trace_task)
+        except Exception as e:
+            logging.debug(f"Error adding trace task: {e}")
+        finally:
+            self.start_timer()
+
+    def collect_tasks(self):
+        global trace_manager_queue
+        tasks = []
+        while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
+            task = trace_manager_queue.get_nowait()
+            tasks.append(task)
+            trace_manager_queue.task_done()
+        return tasks
+
+    def run(self):
+        try:
+            tasks = self.collect_tasks()
+            if tasks:
+                self.send_to_celery(tasks)
+        except Exception as e:
+            logging.debug(f"Error processing trace tasks: {e}")
+
+    def start_timer(self):
+        global trace_manager_timer
+        if trace_manager_timer is None or not trace_manager_timer.is_alive():
+            trace_manager_timer = threading.Timer(
+                trace_manager_interval, self.run
+            )
+            trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
+            trace_manager_timer.daemon = False
+            trace_manager_timer.start()
+
+    def send_to_celery(self, tasks: list[TraceTask]):
+        with self.flask_app.app_context():
+            for task in tasks:
+                trace_info = task.execute()
+                task_data = {
+                    "app_id": self.app_id,
+                    "conversation_id": self.conversation_id,
+                    "message_id": self.message_id,
+                    "trace_info_type": type(trace_info).__name__,
+                    "trace_info": trace_info.model_dump() if trace_info else {},
+                }
+                process_trace_tasks.delay(task_data)
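Taken together, the rewrite replaces the per-instance worker thread with a shared module-level queue that a threading.Timer drains and hands to Celery. A hedged usage sketch from the producer side, where app_id, conversation_id, message_id and trace_task are placeholders for values the caller already has:

    manager = TraceQueueManager(app_id=app_id, conversation_id=conversation_id, message_id=message_id)
    manager.add_trace_task(trace_task)  # enqueue only; no tracing work on the request thread
    # Every trace_manager_interval seconds, the timer re-armed by start_timer()
    # collects up to trace_manager_batch_size tasks and calls
    # process_trace_tasks.delay(task_data) for each one.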