Feat/fix ops trace (#5672)

Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
Joe
2024-06-28 00:24:37 +08:00
committed by GitHub
parent f0ea540b34
commit e8b8f6c6dd
17 changed files with 372 additions and 64 deletions

View File

@@ -94,5 +94,15 @@ class ToolTraceInfo(BaseTraceInfo):
class GenerateNameTraceInfo(BaseTraceInfo):
    """Trace payload for the conversation-name-generation task."""

    # Optional: name generation can run before a conversation exists —
    # the diff artifact that duplicated this field (`conversation_id: str`)
    # has been removed; only the post-change optional form is kept.
    conversation_id: Optional[str] = None
    tenant_id: str
# Registry from trace-info class name to its class. Presumably used to
# rehydrate a serialized payload (see `trace_info_type` in the Celery task
# data) back into its concrete model — confirm against the consumer.
trace_info_info_map = {
    cls.__name__: cls
    for cls in (
        WorkflowTraceInfo,
        MessageTraceInfo,
        ModerationTraceInfo,
        SuggestedQuestionTraceInfo,
        DatasetRetrievalTraceInfo,
        ToolTraceInfo,
        GenerateNameTraceInfo,
    )
}

View File

@@ -147,6 +147,7 @@ class LangFuseDataTrace(BaseTraceInstance):
# add span
if trace_info.message_id:
span_data = LangfuseSpan(
id=node_execution_id,
name=f"{node_name}_{node_execution_id}",
input=inputs,
output=outputs,
@@ -160,6 +161,7 @@ class LangFuseDataTrace(BaseTraceInstance):
)
else:
span_data = LangfuseSpan(
id=node_execution_id,
name=f"{node_name}_{node_execution_id}",
input=inputs,
output=outputs,
@@ -173,6 +175,30 @@ class LangFuseDataTrace(BaseTraceInstance):
self.add_span(langfuse_span_data=span_data)
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
if process_data and process_data.get("model_mode") == "chat":
total_token = metadata.get("total_tokens", 0)
# add generation
generation_usage = GenerationUsage(
totalTokens=total_token,
)
node_generation_data = LangfuseGeneration(
name=f"generation_{node_execution_id}",
trace_id=trace_id,
parent_observation_id=node_execution_id,
start_time=created_at,
end_time=finished_at,
input=inputs,
output=outputs,
metadata=metadata,
level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR,
status_message=trace_info.error if trace_info.error else "",
usage=generation_usage,
)
self.add_generation(langfuse_generation_data=node_generation_data)
def message_trace(
self, trace_info: MessageTraceInfo, **kwargs
):
@@ -186,7 +212,7 @@ class LangFuseDataTrace(BaseTraceInstance):
if message_data.from_end_user_id:
end_user_data: EndUser = db.session.query(EndUser).filter(
EndUser.id == message_data.from_end_user_id
).first().session_id
).first()
user_id = end_user_data.session_id
trace_data = LangfuseTrace(
@@ -220,6 +246,7 @@ class LangFuseDataTrace(BaseTraceInstance):
output=trace_info.answer_tokens,
total=trace_info.total_tokens,
unit=UnitEnum.TOKENS,
totalCost=message_data.total_price,
)
langfuse_generation_data = LangfuseGeneration(
@@ -303,7 +330,7 @@ class LangFuseDataTrace(BaseTraceInstance):
start_time=trace_info.start_time,
end_time=trace_info.end_time,
metadata=trace_info.metadata,
level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
level=LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR,
status_message=trace_info.error,
)

View File

@@ -1,16 +1,17 @@
import json
import logging
import os
import queue
import threading
import time
from datetime import timedelta
from enum import Enum
from typing import Any, Optional, Union
from uuid import UUID
from flask import Flask, current_app
from flask import current_app
from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import (
LangfuseConfig,
LangSmithConfig,
@@ -31,6 +32,7 @@ from core.ops.utils import get_message_data
from extensions.ext_database import db
from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog, WorkflowRun
from tasks.ops_trace_task import process_trace_tasks
provider_config_map = {
TracingProviderEnum.LANGFUSE.value: {
@@ -105,7 +107,7 @@ class OpsTraceManager:
return config_class(**new_config).model_dump()
@classmethod
def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config:dict):
def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
"""
Decrypt tracing config
:param tracing_provider: tracing provider
@@ -295,11 +297,9 @@ class TraceTask:
self.kwargs = kwargs
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
def execute(self):
    """Build and return the trace-info payload for this task.

    Delivery to the tracing backend no longer happens here (the old
    `trace_instance` parameter and dispatch branch were removed); the
    payload is handed to the Celery worker by the queue manager instead.

    :return: the constructed trace info object from ``preprocess``
    """
    # preprocess() returns (method_name, trace_info); the method name is
    # unused here, so discard it explicitly.
    _, trace_info = self.preprocess()
    return trace_info
def preprocess(self):
if self.trace_type == TraceTaskName.CONVERSATION_TRACE:
@@ -372,7 +372,7 @@ class TraceTask:
}
workflow_trace_info = WorkflowTraceInfo(
workflow_data=workflow_run,
workflow_data=workflow_run.to_dict(),
conversation_id=conversation_id,
workflow_id=workflow_id,
tenant_id=tenant_id,
@@ -427,7 +427,8 @@ class TraceTask:
message_tokens = message_data.message_tokens
message_trace_info = MessageTraceInfo(
message_data=message_data,
message_id=message_id,
message_data=message_data.to_dict(),
conversation_model=conversation_mode,
message_tokens=message_tokens,
answer_tokens=message_data.answer_tokens,
@@ -469,7 +470,7 @@ class TraceTask:
moderation_trace_info = ModerationTraceInfo(
message_id=workflow_app_log_id if workflow_app_log_id else message_id,
inputs=inputs,
message_data=message_data,
message_data=message_data.to_dict(),
flagged=moderation_result.flagged,
action=moderation_result.action,
preset_response=moderation_result.preset_response,
@@ -508,7 +509,7 @@ class TraceTask:
suggested_question_trace_info = SuggestedQuestionTraceInfo(
message_id=workflow_app_log_id if workflow_app_log_id else message_id,
message_data=message_data,
message_data=message_data.to_dict(),
inputs=message_data.message,
outputs=message_data.answer,
start_time=timer.get("start"),
@@ -550,11 +551,11 @@ class TraceTask:
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
message_id=message_id,
inputs=message_data.query if message_data.query else message_data.inputs,
documents=documents,
documents=[doc.model_dump() for doc in documents],
start_time=timer.get("start"),
end_time=timer.get("end"),
metadata=metadata,
message_data=message_data,
message_data=message_data.to_dict(),
)
return dataset_retrieval_trace_info
@@ -613,7 +614,7 @@ class TraceTask:
tool_trace_info = ToolTraceInfo(
message_id=message_id,
message_data=message_data,
message_data=message_data.to_dict(),
tool_name=tool_name,
start_time=timer.get("start") if timer else created_time,
end_time=timer.get("end") if timer else end_time,
@@ -657,31 +658,71 @@ class TraceTask:
return generate_name_trace_info
# Module-level dispatch state shared by every TraceQueueManager instance:
# one pending-task queue drained by a single shared threading.Timer.
trace_manager_timer = None
trace_manager_queue = queue.Queue()
# Flush interval in seconds. Note: os.getenv defaults are given as strings
# so the expression is uniformly str -> int (passing an int default only
# worked incidentally because int() accepts int).
trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", "1"))
# Maximum number of queued tasks sent to Celery per flush.
trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", "100"))
class TraceQueueManager:
    """Buffers trace tasks in the module-level queue and periodically
    flushes them to Celery via one shared ``threading.Timer``.

    NOTE(review): the rendered diff had interleaved the removed
    thread-per-manager implementation (``self.queue`` / ``stop`` /
    ``process_queue``) with this timer-based one; the stale removed
    lines are dropped here so the class is coherent again.
    """

    def __init__(self, app_id=None, conversation_id=None, message_id=None):
        global trace_manager_timer

        self.app_id = app_id
        self.conversation_id = conversation_id
        self.message_id = message_id
        # Resolved once per manager; a falsy instance means tracing is
        # disabled and add_trace_task becomes a no-op.
        self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id, conversation_id, message_id)
        # Keep the concrete Flask app so send_to_celery can push a context.
        self.flask_app = current_app._get_current_object()
        if trace_manager_timer is None:
            self.start_timer()

    def add_trace_task(self, trace_task: TraceTask):
        """Enqueue a task; silently skipped when tracing is not configured."""
        global trace_manager_timer
        global trace_manager_queue
        try:
            if self.trace_instance:
                trace_manager_queue.put(trace_task)
        except Exception as e:
            logging.debug(f"Error adding trace task: {e}")
        finally:
            # The timer is one-shot; make sure one is alive for the flush.
            self.start_timer()

    def collect_tasks(self):
        """Drain up to ``trace_manager_batch_size`` tasks from the queue."""
        global trace_manager_queue
        tasks = []
        while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
            task = trace_manager_queue.get_nowait()
            tasks.append(task)
            trace_manager_queue.task_done()
        return tasks

    def run(self):
        """Timer callback: flush one batch of pending tasks to Celery."""
        try:
            tasks = self.collect_tasks()
            if tasks:
                self.send_to_celery(tasks)
        except Exception as e:
            logging.debug(f"Error processing trace tasks: {e}")

    def start_timer(self):
        """(Re)arm the shared one-shot flush timer if none is running."""
        global trace_manager_timer
        if trace_manager_timer is None or not trace_manager_timer.is_alive():
            trace_manager_timer = threading.Timer(
                trace_manager_interval, self.run
            )
            trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
            # Non-daemon so a pending flush is not killed at interpreter exit.
            trace_manager_timer.daemon = False
            trace_manager_timer.start()

    def send_to_celery(self, tasks: list[TraceTask]):
        """Materialize each task's trace info and hand it to the Celery task."""
        with self.flask_app.app_context():
            for task in tasks:
                trace_info = task.execute()
                task_data = {
                    "app_id": self.app_id,
                    "conversation_id": self.conversation_id,
                    "message_id": self.message_id,
                    # Class name lets the worker look the model back up.
                    "trace_info_type": type(trace_info).__name__,
                    "trace_info": trace_info.model_dump() if trace_info else {},
                }
                process_trace_tasks.delay(task_data)