fix: correct tracing for workflows and chatflows for phoenix (#22547)
This commit is contained in:
@@ -3,7 +3,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
|
||||
from opentelemetry import trace
|
||||
@@ -142,11 +142,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
raise
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
workflow_metadata = {
|
||||
"workflow_id": trace_info.workflow_run_id or "",
|
||||
"workflow_run_id": trace_info.workflow_run_id or "",
|
||||
"message_id": trace_info.message_id or "",
|
||||
"workflow_app_log_id": trace_info.workflow_app_log_id or "",
|
||||
"status": trace_info.workflow_run_status or "",
|
||||
@@ -156,7 +153,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
}
|
||||
workflow_metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
trace_id = uuid_to_trace_id(trace_info.workflow_run_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
@@ -213,7 +210,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
if model:
|
||||
node_metadata["ls_model_name"] = model
|
||||
|
||||
outputs = json.loads(node_execution.outputs).get("usage", {})
|
||||
outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
if usage_data:
|
||||
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
|
||||
@@ -236,31 +233,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(created_at),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if node_execution.node_type == "llm":
|
||||
llm_attributes: dict[str, Any] = {
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
|
||||
}
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider:
|
||||
node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider)
|
||||
llm_attributes[SpanAttributes.LLM_PROVIDER] = provider
|
||||
if model:
|
||||
node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model)
|
||||
|
||||
outputs = json.loads(node_execution.outputs).get("usage", {})
|
||||
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model
|
||||
outputs = (
|
||||
json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
|
||||
)
|
||||
usage_data = (
|
||||
process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
)
|
||||
if usage_data:
|
||||
node_span.set_attribute(
|
||||
SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage_data.get("total_tokens", 0)
|
||||
)
|
||||
node_span.set_attribute(
|
||||
SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage_data.get("prompt_tokens", 0)
|
||||
)
|
||||
node_span.set_attribute(
|
||||
SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage_data.get("completion_tokens", 0)
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0)
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = usage_data.get("prompt_tokens", 0)
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = usage_data.get(
|
||||
"completion_tokens", 0
|
||||
)
|
||||
llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
|
||||
node_span.set_attributes(llm_attributes)
|
||||
finally:
|
||||
node_span.end(end_time=datetime_to_nanos(finished_at))
|
||||
finally:
|
||||
@@ -352,25 +352,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
}
|
||||
|
||||
if isinstance(trace_info.inputs, list):
|
||||
for i, msg in enumerate(trace_info.inputs):
|
||||
if isinstance(msg, dict):
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get(
|
||||
"role", "user"
|
||||
)
|
||||
# todo: handle assistant and tool role messages, as they don't always
|
||||
# have a text field, but may have a tool_calls field instead
|
||||
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
|
||||
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
|
||||
elif isinstance(trace_info.inputs, dict):
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs)
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
|
||||
elif isinstance(trace_info.inputs, str):
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
|
||||
|
||||
llm_attributes.update(self._construct_llm_attributes(trace_info.inputs))
|
||||
if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens
|
||||
if trace_info.message_tokens is not None and trace_info.message_tokens > 0:
|
||||
@@ -724,3 +706,24 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
.all()
|
||||
)
|
||||
return workflow_nodes
|
||||
|
||||
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
|
||||
"""Helper method to construct LLM attributes with passed prompts."""
|
||||
attributes = {}
|
||||
if isinstance(prompts, list):
|
||||
for i, msg in enumerate(prompts):
|
||||
if isinstance(msg, dict):
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user")
|
||||
# todo: handle assistant and tool role messages, as they don't always
|
||||
# have a text field, but may have a tool_calls field instead
|
||||
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
|
||||
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
|
||||
elif isinstance(prompts, dict):
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts)
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
|
||||
elif isinstance(prompts, str):
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts
|
||||
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
|
||||
|
||||
return attributes
|
||||
|
Reference in New Issue
Block a user