diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index ffda0885d..8b3ce0c44 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -3,7 +3,7 @@ import json import logging import os from datetime import datetime, timedelta -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes from opentelemetry import trace @@ -142,11 +142,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): raise def workflow_trace(self, trace_info: WorkflowTraceInfo): - if trace_info.message_data is None: - return - workflow_metadata = { - "workflow_id": trace_info.workflow_run_id or "", + "workflow_run_id": trace_info.workflow_run_id or "", "message_id": trace_info.message_id or "", "workflow_app_log_id": trace_info.workflow_app_log_id or "", "status": trace_info.workflow_run_status or "", @@ -156,7 +153,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } workflow_metadata.update(trace_info.metadata) - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = uuid_to_trace_id(trace_info.workflow_run_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, @@ -213,7 +210,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): if model: node_metadata["ls_model_name"] = model - outputs = json.loads(node_execution.outputs).get("usage", {}) + outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {} usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {}) if usage_data: node_metadata["total_tokens"] = usage_data.get("total_tokens", 0) @@ -236,31 +233,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.SESSION_ID: trace_info.conversation_id or "", }, start_time=datetime_to_nanos(created_at), + context=trace.set_span_in_context(trace.NonRecordingSpan(context)), ) try: if node_execution.node_type == "llm": + llm_attributes: dict[str, Any] = { + SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False), + } provider = process_data.get("model_provider") model = process_data.get("model_name") if provider: - node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider) + llm_attributes[SpanAttributes.LLM_PROVIDER] = provider if model: - node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model) - - outputs = json.loads(node_execution.outputs).get("usage", {}) + llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model + outputs = ( + json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {} + ) usage_data = ( process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {}) ) if usage_data: - node_span.set_attribute( - SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage_data.get("total_tokens", 0) - ) - node_span.set_attribute( - SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage_data.get("prompt_tokens", 0) - ) - node_span.set_attribute( - SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage_data.get("completion_tokens", 0) + llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0) + llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = usage_data.get("prompt_tokens", 0) + llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = usage_data.get( + "completion_tokens", 0 ) + llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", []))) + node_span.set_attributes(llm_attributes) finally: node_span.end(end_time=datetime_to_nanos(finished_at)) finally: @@ -352,25 +352,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False), SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id, } - - if isinstance(trace_info.inputs, list): - for i, msg in enumerate(trace_info.inputs): - if isinstance(msg, dict): - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "") - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get( - "role", "user" - ) - # todo: handle assistant and tool role messages, as they don't always - # have a text field, but may have a tool_calls field instead - # e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58', - # 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]} - elif isinstance(trace_info.inputs, dict): - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs) - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" - elif isinstance(trace_info.inputs, str): - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" - + llm_attributes.update(self._construct_llm_attributes(trace_info.inputs)) if trace_info.total_tokens is not None and trace_info.total_tokens > 0: llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens if trace_info.message_tokens is not None and trace_info.message_tokens > 0: @@ -724,3 +706,24 @@ class ArizePhoenixDataTrace(BaseTraceInstance): .all() ) return workflow_nodes + + def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]: + """Helper method to construct LLM attributes with passed prompts.""" + attributes = {} + if isinstance(prompts, list): + for i, msg in enumerate(prompts): + if isinstance(msg, dict): + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "") + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user") + # todo: handle assistant and tool role messages, as they don't always + # have a text field, but may have a tool_calls field instead + # e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58', + # 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]} + elif isinstance(prompts, dict): + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts) + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" + elif isinstance(prompts, str): + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" + + return attributes