def _parse_stringified_attribute(key: str, value: str) -> Any:
    """Decode one stringified attribute, trying JSON first and a Python literal second.

    JSON is attempted first because some JSON documents are also valid Python
    literals but decode differently under Python rules (for example ``\\/`` keeps
    its backslash and a ``\\uD83D\\uDE00`` surrogate pair is not combined).

    Args:
        key: Attribute name, used only in debug log messages.
        value: The raw string to decode.

    Returns:
        The decoded object, or ``value`` unchanged when neither parser accepts it.
    """
    try:
        return json.loads(value)
    except (ValueError, RecursionError) as json_error:
        # json.JSONDecodeError is a ValueError subclass. Keep a reference because
        # the ``as`` name is unbound when the except block ends.
        json_failure = json_error

    try:
        # literal_eval only builds literals (no calls or names), so it is the safe
        # way to read a Python repr such as "{'finish_reason': 'stop'}".
        return ast.literal_eval(value)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as literal_error:
        # Only report when both parsers failed; a JSON failure alone is the
        # expected path for Python-repr values.
        logger.debug("Failed to parse %s with json.loads: %s", key, json_failure)
        logger.debug("Failed to parse %s with ast.literal_eval: %s", key, literal_error)
        return value


def extract_openai_response(observations: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Attempt to extract and parse attributes from the raw_gen_ai_request observation.

    This only works when the trace was stored in OTEL format, i.e. when the
    observation carries ``metadata["attributes"]``.

    Args:
        observations: List of observation dictionaries from the trace.

    Returns:
        A dict with every attribute of the first ``raw_gen_ai_request`` SPAN that
        has a non-empty attribute mapping. String values starting with ``[`` or
        ``{`` are decoded (JSON or Python literal); all other values are passed
        through unchanged. Returns None if no such observation is found.
    """
    for obs in observations or []:
        if not isinstance(obs, dict):
            continue
        if obs.get("name") != "raw_gen_ai_request" or obs.get("type") != "SPAN":
            continue

        metadata = obs.get("metadata")
        attributes = metadata.get("attributes") if isinstance(metadata, dict) else None
        # Skip spans without a usable attribute mapping and keep looking.
        if not isinstance(attributes, dict) or not attributes:
            continue

        return {
            key: (
                _parse_stringified_attribute(key, value)
                if isinstance(value, str) and value.startswith(("[", "{"))
                else value
            )
            for key, value in attributes.items()
        }

    return None
early if we've found all the metadata we need + observations = trace.get("observations") or [] + # We can only extract when stored in OTEL format. + openai_response = extract_openai_response(observations) + if openai_response: + choices = openai_response.get("llm.openai.choices") + if choices and len(choices) > 0: + execution_metadata.finish_reason = choices[0].get("finish_reason") + return EvaluationRow( messages=messages, tools=tools, @@ -160,7 +207,7 @@ def extract_messages_from_trace_dict( # Fallback: use the last GENERATION observation which typically contains full chat history if not messages: try: - all_observations = trace.get("observations", []) + all_observations = trace.get("observations") or [] gens = [obs for obs in all_observations if obs.get("type") == "GENERATION"] if gens: gens.sort(key=lambda x: x.get("start_time", "")) @@ -186,7 +233,7 @@ def get_final_generation_in_span_dict(trace: Dict[str, Any], span_name: str) -> The final generation dictionary, or None if not found """ # Get all observations from the trace - all_observations = trace.get("observations", []) + all_observations = trace.get("observations") or [] # Find a span with the given name that has generation children parent_span = None diff --git a/eval_protocol/proxy/proxy_core/langfuse.py b/eval_protocol/proxy/proxy_core/langfuse.py index d91da681..ec0e9475 100644 --- a/eval_protocol/proxy/proxy_core/langfuse.py +++ b/eval_protocol/proxy/proxy_core/langfuse.py @@ -50,6 +50,7 @@ def _serialize_trace_to_dict(trace_full: Any) -> Dict[str, Any]: "input": getattr(obs, "input", None), "output": getattr(obs, "output", None), "parent_observation_id": getattr(obs, "parent_observation_id", None), + "metadata": getattr(obs, "metadata", None), } for obs in getattr(trace_full, "observations", []) ] diff --git a/eval_protocol/reward_function.py b/eval_protocol/reward_function.py index 6bd11974..743d3c7c 100644 --- a/eval_protocol/reward_function.py +++ b/eval_protocol/reward_function.py @@ -12,7 
class StatusLogCaptureHandler(logging.Handler):
    """Logging handler that captures status log messages.

    Every record whose formatted message contains both "Found Fireworks log"
    and "with status code 100" is appended to ``status_100_messages``.
    """

    def __init__(self):
        super().__init__()
        # Fully formatted matching messages, in the order they were emitted.
        self.status_100_messages: List[str] = []

    def emit(self, record):
        """Record the message if it is a status-100 "Found Fireworks log" line.

        Args:
            record: The ``logging.LogRecord`` being handled.
        """
        try:
            # Use getMessage() (msg % args); the .message attribute is only set
            # once a Formatter has processed the record.
            msg = record.getMessage()
        except Exception:
            # A handler must not raise into the code that issued the log call;
            # report malformed records through the standard logging hook.
            self.handleError(record)
            return

        if "Found Fireworks log" in msg and "with status code 100" in msg:
            self.status_100_messages.append(msg)
This wraps the built-in default_fireworks_output_data_loader (without making it configurable) and tracks rollout_ids passed through its DataLoaderConfig. @@ -37,9 +51,32 @@ def wrapped_loader(config: DataLoaderConfig) -> DynamicDataLoader: return original_loader(config) monkeypatch.setattr(remote_rollout_processor_module, "default_fireworks_output_data_loader", wrapped_loader) + + # Add custom handler to capture status logs + status_handler = StatusLogCaptureHandler() + status_handler.setLevel(logging.INFO) + rrp_logger = logging.getLogger("eval_protocol.pytest.remote_rollout_processor") + rrp_logger.addHandler(status_handler) + # Ensure the logger level allows INFO messages through + original_level = rrp_logger.level + rrp_logger.setLevel(logging.INFO) + yield + + # Cleanup handler and restore level + rrp_logger.removeHandler(status_handler) + rrp_logger.setLevel(original_level) + + # After test completes, verify we saw status logs for all 3 rollouts assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" + # Check that we received "Found Fireworks log ... with status code 100" for each rollout + assert len(status_handler.status_100_messages) == 3, ( + f"Expected 3 'Found Fireworks log ... with status code 100' messages, but found {len(status_handler.status_100_messages)}. " + f"This means the status logs from the remote server were not received. " + f"Messages captured: {status_handler.status_100_messages}" + ) + def find_available_port() -> int: """Find an available port on localhost""" @@ -141,4 +178,8 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat assert "data_loader_type" in row.input_metadata.dataset_info assert "data_loader_num_rows" in row.input_metadata.dataset_info + assert row.execution_metadata.finish_reason == "stop", ( + f"Expected finish_reason='stop', got {row.execution_metadata.finish_reason}" + ) + return row