53 changes: 50 additions & 3 deletions eval_protocol/adapters/fireworks_tracing.py
@@ -8,8 +8,10 @@
import logging
import requests
from datetime import datetime
from typing import Any, Dict, List, Optional, Protocol
import ast
import json
import os
from typing import Any, Dict, List, Optional, Protocol

from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message
from .base import BaseAdapter
@@ -44,6 +46,43 @@ def __call__(
...


def extract_openai_response(observations: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""Attempt to extract and parse attributes from raw_gen_ai_request observation. This only works when stored in OTEL format.

Args:
observations: List of observation dictionaries from the trace

Returns:
Dict with all attributes parsed, or None if no matching observation is found.
"""
for obs in observations:
if obs.get("name") == "raw_gen_ai_request" and obs.get("type") == "SPAN":
metadata = obs.get("metadata") or {}
attributes = metadata.get("attributes") or {}

result: Dict[str, Any] = {}

for key, value in attributes.items():
# Try to parse stringified objects (could be Python repr or JSON)
if isinstance(value, str) and value.startswith(("[", "{")):
try:
result[key] = ast.literal_eval(value)
except Exception as e:
logger.debug("Failed to parse %s with ast.literal_eval: %s", key, e)
try:
result[key] = json.loads(value)
except Exception as e:
logger.debug("Failed to parse %s with json.loads: %s", key, e)
result[key] = value
else:
result[key] = value

if result:
return result

return None


def convert_trace_dict_to_evaluation_row(
trace: Dict[str, Any], include_tool_calls: bool = True, span_name: Optional[str] = None
) -> Optional[EvaluationRow]:
@@ -96,6 +135,14 @@ def convert_trace_dict_to_evaluation_row(
):
break # Break early if we've found all the metadata we need

observations = trace.get("observations") or []
# We can only extract when stored in OTEL format.
openai_response = extract_openai_response(observations)
if openai_response:
choices = openai_response.get("llm.openai.choices")
if choices and len(choices) > 0:
execution_metadata.finish_reason = choices[0].get("finish_reason")

return EvaluationRow(
messages=messages,
tools=tools,
@@ -160,7 +207,7 @@ def extract_messages_from_trace_dict(
# Fallback: use the last GENERATION observation which typically contains full chat history
if not messages:
try:
all_observations = trace.get("observations", [])
all_observations = trace.get("observations") or []
gens = [obs for obs in all_observations if obs.get("type") == "GENERATION"]
if gens:
gens.sort(key=lambda x: x.get("start_time", ""))
@@ -186,7 +233,7 @@ def get_final_generation_in_span_dict(trace: Dict[str, Any], span_name: str) ->
The final generation dictionary, or None if not found
"""
# Get all observations from the trace
all_observations = trace.get("observations", [])
all_observations = trace.get("observations") or []

# Find a span with the given name that has generation children
parent_span = None
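For reference, a minimal usage sketch of the new extract_openai_response helper (not part of this diff). The observation shape and the llm.openai.model key are assumed for illustration; the llm.openai.choices key matches the one read in convert_trace_dict_to_evaluation_row:

    # Hypothetical OTEL-style observation: metadata.attributes holds stringified values.
    observations = [
        {
            "name": "raw_gen_ai_request",
            "type": "SPAN",
            "metadata": {
                "attributes": {
                    # Parsed via ast.literal_eval (falling back to json.loads, then the raw string)
                    "llm.openai.choices": '[{"index": 0, "finish_reason": "stop"}]',
                    # Non-stringified values are passed through unchanged
                    "llm.openai.model": "accounts/fireworks/models/llama-v3p1-8b-instruct",
                }
            },
        }
    ]

    parsed = extract_openai_response(observations)
    if parsed:
        choices = parsed.get("llm.openai.choices") or []
        finish_reason = choices[0].get("finish_reason") if choices else None  # -> "stop"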
1 change: 1 addition & 0 deletions eval_protocol/proxy/proxy_core/langfuse.py
@@ -50,6 +50,7 @@ def _serialize_trace_to_dict(trace_full: Any) -> Dict[str, Any]:
"input": getattr(obs, "input", None),
"output": getattr(obs, "output", None),
"parent_observation_id": getattr(obs, "parent_observation_id", None),
"metadata": getattr(obs, "metadata", None),
}
for obs in getattr(trace_full, "observations", [])
]
1 change: 0 additions & 1 deletion eval_protocol/reward_function.py
@@ -12,7 +12,6 @@
from .models import EvaluateResult, MetricResult
from .typed_interface import reward_function

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

T = TypeVar("T", bound=Callable[..., EvaluateResult])
3 changes: 3 additions & 0 deletions tests/remote_server/remote_server.py
@@ -13,6 +13,9 @@

app = FastAPI()

# Configure logging for the remote server (required for INFO-level logs to be emitted)
logging.basicConfig(level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s")

# Attach Fireworks tracing handler to root logger
fireworks_handler = FireworksTracingHttpHandler()
logging.getLogger().addHandler(fireworks_handler)
43 changes: 42 additions & 1 deletion tests/remote_server/test_remote_fireworks.py
@@ -1,5 +1,6 @@
# AUTO SERVER STARTUP: Server is automatically started and stopped by the test

import logging
import subprocess
import socket
import time
@@ -19,10 +20,23 @@
ROLLOUT_IDS = set()


class StatusLogCaptureHandler(logging.Handler):
"""Custom handler to capture status log messages."""

def __init__(self):
super().__init__()
self.status_100_messages: List[str] = []

def emit(self, record):
msg = record.getMessage() # Use getMessage(), not .message attribute
if "Found Fireworks log" in msg and "with status code 100" in msg:
self.status_100_messages.append(msg)


@pytest.fixture(autouse=True)
def check_rollout_coverage(monkeypatch):
"""
Ensure we attempted to fetch remote traces for each rollout.
Ensure we attempted to fetch remote traces for each rollout and received status logs.

This wraps the built-in default_fireworks_output_data_loader (without making it configurable)
and tracks rollout_ids passed through its DataLoaderConfig.
@@ -37,9 +51,32 @@ def wrapped_loader(config: DataLoaderConfig) -> DynamicDataLoader:
return original_loader(config)

monkeypatch.setattr(remote_rollout_processor_module, "default_fireworks_output_data_loader", wrapped_loader)

# Add custom handler to capture status logs
status_handler = StatusLogCaptureHandler()
status_handler.setLevel(logging.INFO)
rrp_logger = logging.getLogger("eval_protocol.pytest.remote_rollout_processor")
rrp_logger.addHandler(status_handler)
# Ensure the logger level allows INFO messages through
original_level = rrp_logger.level
rrp_logger.setLevel(logging.INFO)

yield

# Cleanup handler and restore level
rrp_logger.removeHandler(status_handler)
rrp_logger.setLevel(original_level)

# After test completes, verify we saw status logs for all 3 rollouts
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"

# Check that we received "Found Fireworks log ... with status code 100" for each rollout
assert len(status_handler.status_100_messages) == 3, (
f"Expected 3 'Found Fireworks log ... with status code 100' messages, but found {len(status_handler.status_100_messages)}. "
f"This means the status logs from the remote server were not received. "
f"Messages captured: {status_handler.status_100_messages}"
)


def find_available_port() -> int:
"""Find an available port on localhost"""
@@ -141,4 +178,8 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat
assert "data_loader_type" in row.input_metadata.dataset_info
assert "data_loader_num_rows" in row.input_metadata.dataset_info

assert row.execution_metadata.finish_reason == "stop", (
f"Expected finish_reason='stop', got {row.execution_metadata.finish_reason}"
)

return row