From fd9725e88b117eb0f64e99040bf251ecf9310c2a Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 26 Jan 2026 14:08:47 -0800 Subject: [PATCH 1/6] add finish reason --- eval_protocol/adapters/fireworks_tracing.py | 43 +++- eval_protocol/proxy/proxy_core/langfuse.py | 1 + scripts/fetch_traces_test.py | 219 ++++++++++++++++++++ 3 files changed, 262 insertions(+), 1 deletion(-) create mode 100644 scripts/fetch_traces_test.py diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 4913e33b..4007f72b 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -8,8 +8,9 @@ import logging import requests from datetime import datetime -from typing import Any, Dict, List, Optional, Protocol +import ast import os +from typing import Any, Dict, List, Optional, Protocol from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message from .base import BaseAdapter @@ -44,6 +45,38 @@ def __call__( ... +def extract_openai_response(observations: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + """Attempt to extract and parse attributes from raw_gen_ai_request observation. This only works when stored in OTEL format. + + Args: + observations: List of observation dictionaries from the trace + + Returns: + Dict with all attributes parsed. Or None if not found. + """ + for obs in observations: + if obs.get("name") == "raw_gen_ai_request" and obs.get("type") == "SPAN": + metadata = obs.get("metadata", {}) + attributes = metadata.get("attributes", {}) + + result: Dict[str, Any] = {} + + for key, value in attributes.items(): + # Try to parse stringified Python literals, otherwise keep as-is + if isinstance(value, str) and value.startswith(("[", "{")): + try: + result[key] = ast.literal_eval(value) + except Exception: + result[key] = value + else: + result[key] = value + + if result: + return result + + return None + + def convert_trace_dict_to_evaluation_row( trace: Dict[str, Any], include_tool_calls: bool = True, span_name: Optional[str] = None ) -> Optional[EvaluationRow]: @@ -96,6 +129,14 @@ def convert_trace_dict_to_evaluation_row( ): break # Break early if we've found all the metadata we need + observations = trace.get("observations", []) + # We can only extract when stored in OTEL format. + openai_response = extract_openai_response(observations) + if openai_response: + choices = openai_response.get("llm.openai.choices") + if choices and len(choices) > 0: + execution_metadata.finish_reason = choices[0].get("finish_reason") + return EvaluationRow( messages=messages, tools=tools, diff --git a/eval_protocol/proxy/proxy_core/langfuse.py b/eval_protocol/proxy/proxy_core/langfuse.py index d91da681..ec0e9475 100644 --- a/eval_protocol/proxy/proxy_core/langfuse.py +++ b/eval_protocol/proxy/proxy_core/langfuse.py @@ -50,6 +50,7 @@ def _serialize_trace_to_dict(trace_full: Any) -> Dict[str, Any]: "input": getattr(obs, "input", None), "output": getattr(obs, "output", None), "parent_observation_id": getattr(obs, "parent_observation_id", None), + "metadata": getattr(obs, "metadata", None), } for obs in getattr(trace_full, "observations", []) ] diff --git a/scripts/fetch_traces_test.py b/scripts/fetch_traces_test.py new file mode 100644 index 00000000..2ca62609 --- /dev/null +++ b/scripts/fetch_traces_test.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +"""Simple script to fetch traces directly from Langfuse and parse them. + +This bypasses the Fireworks tracing proxy (and its Redis insertion_id check) +by querying Langfuse directly. + +Required env vars: + LANGFUSE_PUBLIC_KEY - Your Langfuse public key + LANGFUSE_SECRET_KEY - Your Langfuse secret key + LANGFUSE_HOST - Langfuse host (default: https://cloud.langfuse.com) + ROLLOUT_ID - The rollout_id to search for (default: test-test-test) +""" + +import json +import os +from datetime import datetime, timedelta +from typing import List, Dict, Any + +from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row +from eval_protocol.models import EvaluationRow + + +os.environ.setdefault("LANGFUSE_PUBLIC_KEY", "pk-lf-9470ba98-7ace-4fe0-b1dc-3dda0f66d812") +os.environ.setdefault("LANGFUSE_SECRET_KEY", "sk-lf-36b11237-a230-4524-a6e0-3af372b6f5b6") +os.environ.setdefault("LANGFUSE_HOST", "https://langfuse-prod.fireworks.ai") # EU region + + +def fetch_traces_from_langfuse( + tags: List[str], + limit: int = 100, + hours_back: int = 24, +) -> List[Dict[str, Any]]: + """Fetch traces directly from Langfuse (bypassing Fireworks proxy). + + This avoids the Redis insertion_id check by going straight to Langfuse. + """ + try: + from langfuse import Langfuse + except ImportError: + print("ERROR: langfuse not installed. Run: pip install langfuse") + return [] + + # Get Langfuse credentials from environment + public_key = os.environ.get("LANGFUSE_PUBLIC_KEY") + secret_key = os.environ.get("LANGFUSE_SECRET_KEY") + host = os.environ.get("LANGFUSE_HOST", "https://cloud.langfuse.com") + + if not public_key or not secret_key: + print("ERROR: LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY must be set") + return [] + + print(f"Connecting to Langfuse at {host}...") + client = Langfuse(public_key=public_key, secret_key=secret_key, host=host) + + # Calculate time range + to_ts = datetime.now() + from_ts = to_ts - timedelta(hours=hours_back) + + print(f"Fetching traces with tags: {tags}") + print(f"Time range: {from_ts} to {to_ts}") + + # Fetch trace list + traces_response = client.api.trace.list( + page=1, + limit=limit, + tags=tags, + from_timestamp=from_ts, + to_timestamp=to_ts, + order_by="timestamp.desc", + ) + + if not traces_response or not traces_response.data: + print("No traces found in list response") + return [] + + print(f"Found {len(traces_response.data)} trace summaries") + + # Fetch full trace details and serialize to dict + traces: List[Dict[str, Any]] = [] + for trace_info in traces_response.data: + try: + trace_full = client.api.trace.get(trace_info.id) + + # Serialize to dict (same format as proxy returns) + trace_dict = _serialize_trace_to_dict(trace_full) + traces.append(trace_dict) + + except Exception as e: + print(f" Failed to fetch trace {trace_info.id}: {e}") + + print(f"Successfully fetched {len(traces)} full traces") + return traces + + +def _serialize_trace_to_dict(trace_full: Any) -> Dict[str, Any]: + """Convert Langfuse trace object to dict format (same as proxy does).""" + timestamp = getattr(trace_full, "timestamp", None) + + return { + "id": trace_full.id, + "name": getattr(trace_full, "name", None), + "user_id": getattr(trace_full, "user_id", None), + "session_id": getattr(trace_full, "session_id", None), + "tags": getattr(trace_full, "tags", []), + "timestamp": str(timestamp) if timestamp else None, + "input": getattr(trace_full, "input", None), + "output": getattr(trace_full, "output", None), + "metadata": getattr(trace_full, "metadata", None), + "observations": [ + { + "id": obs.id, + "type": getattr(obs, "type", None), + "name": getattr(obs, "name", None), + "start_time": str(getattr(obs, "start_time", None)) if getattr(obs, "start_time", None) else None, + "end_time": str(getattr(obs, "end_time", None)) if getattr(obs, "end_time", None) else None, + "input": getattr(obs, "input", None), + "output": getattr(obs, "output", None), + "parent_observation_id": getattr(obs, "parent_observation_id", None), + "metadata": getattr(obs, "metadata", None), + } + for obs in getattr(trace_full, "observations", []) + ] + if hasattr(trace_full, "observations") + else [], + } + + +def parse_traces_to_rows(traces: List[Dict[str, Any]], include_tool_calls: bool = True) -> List[EvaluationRow]: + """Parse raw trace dicts to EvaluationRows using the same logic as get_evaluation_rows.""" + rows = [] + for trace in traces: + try: + row = convert_trace_dict_to_evaluation_row(trace, include_tool_calls) + if row: + rows.append(row) + except Exception as e: + print(f" Failed to convert trace {trace.get('id')}: {e}") + return rows + + +def print_row_details(row: EvaluationRow, index: int): + """Print details of a single EvaluationRow.""" + print(f"\n--- Row {index + 1} ---") + print(f"Row ID: {row.input_metadata.row_id}") + print( + f"Trace ID: {row.input_metadata.session_data.get('langfuse_trace_id') if row.input_metadata.session_data else None}" + ) + print(f"Rollout ID: {row.execution_metadata.rollout_id}") + print(f"Invocation ID: {row.execution_metadata.invocation_id}") + print(f"Experiment ID: {row.execution_metadata.experiment_id}") + print(f"Run ID: {row.execution_metadata.run_id}") + print(f"Finish Reason: {row.execution_metadata.finish_reason}") # NEW + print(f"Num messages: {len(row.messages)}") + print(f"Tools: {row.tools is not None}") + + print("\nMessages:") + for j, msg in enumerate(row.messages): + content_preview = str(msg.content)[:100] if msg.content else "(empty)" + tool_calls_info = f" [tool_calls: {len(msg.tool_calls)}]" if msg.tool_calls else "" + print(f" [{j}] {msg.role}: {content_preview}{tool_calls_info}") + + +def main(): + rollout_id = os.environ.get("ROLLOUT_ID", "test-test-test") + hours_back = int(os.environ.get("HOURS_BACK", "24")) + + print(f"Rollout ID: {rollout_id}") + print(f"Hours back: {hours_back}") + print("=" * 60) + + # Step 1: Fetch raw traces directly from Langfuse + print("\n[1] Fetching raw traces from Langfuse...") + traces = fetch_traces_from_langfuse( + tags=[f"rollout_id:{rollout_id}"], + limit=10, + hours_back=hours_back, + ) + + if not traces: + print("\nNo traces found!") + return + + # Step 2: Print raw trace structure (first trace only) + print("\n[2] Raw trace structure (first trace):") + print("-" * 60) + first_trace = traces[0] + print(f"ID: {first_trace.get('id')}") + print(f"Name: {first_trace.get('name')}") + print(f"Tags: {first_trace.get('tags')}") + print(f"Input type: {type(first_trace.get('input'))}") + print(f"Input: {json.dumps(first_trace.get('input'), indent=2)[:500]}...") + print(f"Output type: {type(first_trace.get('output'))}") + print(f"Output: {json.dumps(first_trace.get('output'), indent=2)[:500] if first_trace.get('output') else None}...") + print(f"Num observations: {len(first_trace.get('observations', []))}") + + # Print observations + for obs in first_trace.get("observations", []): + print(f"\n Observation: {obs.get('name')} ({obs.get('type')})") + print(f" Input type: {type(obs.get('input'))}") + print(f" Input: {json.dumps(obs.get('input'), indent=2)[:300] if obs.get('input') else None}...") + print(f" Output type: {type(obs.get('output'))}") + print(f" Output: {json.dumps(obs.get('output'), indent=2)[:300] if obs.get('output') else None}...") + + # Step 3: Parse to EvaluationRows + print("\n[3] Parsing traces to EvaluationRows...") + print("-" * 60) + rows = parse_traces_to_rows(traces) + + print(f"\nSuccessfully parsed {len(rows)} / {len(traces)} traces") + + # Step 4: Print row details + print("\n[4] EvaluationRow details:") + print("=" * 60) + for i, row in enumerate(rows): + print_row_details(row, i) + + +if __name__ == "__main__": + main() From becb3072624047d82c717cb532a7c3482e71fb6f Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 26 Jan 2026 14:09:10 -0800 Subject: [PATCH 2/6] remove --- scripts/fetch_traces_test.py | 219 ----------------------------------- 1 file changed, 219 deletions(-) delete mode 100644 scripts/fetch_traces_test.py diff --git a/scripts/fetch_traces_test.py b/scripts/fetch_traces_test.py deleted file mode 100644 index 2ca62609..00000000 --- a/scripts/fetch_traces_test.py +++ /dev/null @@ -1,219 +0,0 @@ -#!/usr/bin/env python3 -"""Simple script to fetch traces directly from Langfuse and parse them. - -This bypasses the Fireworks tracing proxy (and its Redis insertion_id check) -by querying Langfuse directly. - -Required env vars: - LANGFUSE_PUBLIC_KEY - Your Langfuse public key - LANGFUSE_SECRET_KEY - Your Langfuse secret key - LANGFUSE_HOST - Langfuse host (default: https://cloud.langfuse.com) - ROLLOUT_ID - The rollout_id to search for (default: test-test-test) -""" - -import json -import os -from datetime import datetime, timedelta -from typing import List, Dict, Any - -from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row -from eval_protocol.models import EvaluationRow - - -os.environ.setdefault("LANGFUSE_PUBLIC_KEY", "pk-lf-9470ba98-7ace-4fe0-b1dc-3dda0f66d812") -os.environ.setdefault("LANGFUSE_SECRET_KEY", "sk-lf-36b11237-a230-4524-a6e0-3af372b6f5b6") -os.environ.setdefault("LANGFUSE_HOST", "https://langfuse-prod.fireworks.ai") # EU region - - -def fetch_traces_from_langfuse( - tags: List[str], - limit: int = 100, - hours_back: int = 24, -) -> List[Dict[str, Any]]: - """Fetch traces directly from Langfuse (bypassing Fireworks proxy). - - This avoids the Redis insertion_id check by going straight to Langfuse. - """ - try: - from langfuse import Langfuse - except ImportError: - print("ERROR: langfuse not installed. Run: pip install langfuse") - return [] - - # Get Langfuse credentials from environment - public_key = os.environ.get("LANGFUSE_PUBLIC_KEY") - secret_key = os.environ.get("LANGFUSE_SECRET_KEY") - host = os.environ.get("LANGFUSE_HOST", "https://cloud.langfuse.com") - - if not public_key or not secret_key: - print("ERROR: LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY must be set") - return [] - - print(f"Connecting to Langfuse at {host}...") - client = Langfuse(public_key=public_key, secret_key=secret_key, host=host) - - # Calculate time range - to_ts = datetime.now() - from_ts = to_ts - timedelta(hours=hours_back) - - print(f"Fetching traces with tags: {tags}") - print(f"Time range: {from_ts} to {to_ts}") - - # Fetch trace list - traces_response = client.api.trace.list( - page=1, - limit=limit, - tags=tags, - from_timestamp=from_ts, - to_timestamp=to_ts, - order_by="timestamp.desc", - ) - - if not traces_response or not traces_response.data: - print("No traces found in list response") - return [] - - print(f"Found {len(traces_response.data)} trace summaries") - - # Fetch full trace details and serialize to dict - traces: List[Dict[str, Any]] = [] - for trace_info in traces_response.data: - try: - trace_full = client.api.trace.get(trace_info.id) - - # Serialize to dict (same format as proxy returns) - trace_dict = _serialize_trace_to_dict(trace_full) - traces.append(trace_dict) - - except Exception as e: - print(f" Failed to fetch trace {trace_info.id}: {e}") - - print(f"Successfully fetched {len(traces)} full traces") - return traces - - -def _serialize_trace_to_dict(trace_full: Any) -> Dict[str, Any]: - """Convert Langfuse trace object to dict format (same as proxy does).""" - timestamp = getattr(trace_full, "timestamp", None) - - return { - "id": trace_full.id, - "name": getattr(trace_full, "name", None), - "user_id": getattr(trace_full, "user_id", None), - "session_id": getattr(trace_full, "session_id", None), - "tags": getattr(trace_full, "tags", []), - "timestamp": str(timestamp) if timestamp else None, - "input": getattr(trace_full, "input", None), - "output": getattr(trace_full, "output", None), - "metadata": getattr(trace_full, "metadata", None), - "observations": [ - { - "id": obs.id, - "type": getattr(obs, "type", None), - "name": getattr(obs, "name", None), - "start_time": str(getattr(obs, "start_time", None)) if getattr(obs, "start_time", None) else None, - "end_time": str(getattr(obs, "end_time", None)) if getattr(obs, "end_time", None) else None, - "input": getattr(obs, "input", None), - "output": getattr(obs, "output", None), - "parent_observation_id": getattr(obs, "parent_observation_id", None), - "metadata": getattr(obs, "metadata", None), - } - for obs in getattr(trace_full, "observations", []) - ] - if hasattr(trace_full, "observations") - else [], - } - - -def parse_traces_to_rows(traces: List[Dict[str, Any]], include_tool_calls: bool = True) -> List[EvaluationRow]: - """Parse raw trace dicts to EvaluationRows using the same logic as get_evaluation_rows.""" - rows = [] - for trace in traces: - try: - row = convert_trace_dict_to_evaluation_row(trace, include_tool_calls) - if row: - rows.append(row) - except Exception as e: - print(f" Failed to convert trace {trace.get('id')}: {e}") - return rows - - -def print_row_details(row: EvaluationRow, index: int): - """Print details of a single EvaluationRow.""" - print(f"\n--- Row {index + 1} ---") - print(f"Row ID: {row.input_metadata.row_id}") - print( - f"Trace ID: {row.input_metadata.session_data.get('langfuse_trace_id') if row.input_metadata.session_data else None}" - ) - print(f"Rollout ID: {row.execution_metadata.rollout_id}") - print(f"Invocation ID: {row.execution_metadata.invocation_id}") - print(f"Experiment ID: {row.execution_metadata.experiment_id}") - print(f"Run ID: {row.execution_metadata.run_id}") - print(f"Finish Reason: {row.execution_metadata.finish_reason}") # NEW - print(f"Num messages: {len(row.messages)}") - print(f"Tools: {row.tools is not None}") - - print("\nMessages:") - for j, msg in enumerate(row.messages): - content_preview = str(msg.content)[:100] if msg.content else "(empty)" - tool_calls_info = f" [tool_calls: {len(msg.tool_calls)}]" if msg.tool_calls else "" - print(f" [{j}] {msg.role}: {content_preview}{tool_calls_info}") - - -def main(): - rollout_id = os.environ.get("ROLLOUT_ID", "test-test-test") - hours_back = int(os.environ.get("HOURS_BACK", "24")) - - print(f"Rollout ID: {rollout_id}") - print(f"Hours back: {hours_back}") - print("=" * 60) - - # Step 1: Fetch raw traces directly from Langfuse - print("\n[1] Fetching raw traces from Langfuse...") - traces = fetch_traces_from_langfuse( - tags=[f"rollout_id:{rollout_id}"], - limit=10, - hours_back=hours_back, - ) - - if not traces: - print("\nNo traces found!") - return - - # Step 2: Print raw trace structure (first trace only) - print("\n[2] Raw trace structure (first trace):") - print("-" * 60) - first_trace = traces[0] - print(f"ID: {first_trace.get('id')}") - print(f"Name: {first_trace.get('name')}") - print(f"Tags: {first_trace.get('tags')}") - print(f"Input type: {type(first_trace.get('input'))}") - print(f"Input: {json.dumps(first_trace.get('input'), indent=2)[:500]}...") - print(f"Output type: {type(first_trace.get('output'))}") - print(f"Output: {json.dumps(first_trace.get('output'), indent=2)[:500] if first_trace.get('output') else None}...") - print(f"Num observations: {len(first_trace.get('observations', []))}") - - # Print observations - for obs in first_trace.get("observations", []): - print(f"\n Observation: {obs.get('name')} ({obs.get('type')})") - print(f" Input type: {type(obs.get('input'))}") - print(f" Input: {json.dumps(obs.get('input'), indent=2)[:300] if obs.get('input') else None}...") - print(f" Output type: {type(obs.get('output'))}") - print(f" Output: {json.dumps(obs.get('output'), indent=2)[:300] if obs.get('output') else None}...") - - # Step 3: Parse to EvaluationRows - print("\n[3] Parsing traces to EvaluationRows...") - print("-" * 60) - rows = parse_traces_to_rows(traces) - - print(f"\nSuccessfully parsed {len(rows)} / {len(traces)} traces") - - # Step 4: Print row details - print("\n[4] EvaluationRow details:") - print("=" * 60) - for i, row in enumerate(rows): - print_row_details(row, i) - - -if __name__ == "__main__": - main() From fc1e273065cf9ddf13869ab5cd494a3f017e9946 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 26 Jan 2026 14:20:40 -0800 Subject: [PATCH 3/6] add test --- tests/remote_server/test_remote_fireworks.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index 43da29ed..145e97c3 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -141,4 +141,8 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat assert "data_loader_type" in row.input_metadata.dataset_info assert "data_loader_num_rows" in row.input_metadata.dataset_info + assert row.execution_metadata.finish_reason == "stop", ( + f"Expected finish_reason='stop', got {row.execution_metadata.finish_reason}" + ) + return row From a057dfc0d45d487cfe03c1be5946c06fa5de9222 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 26 Jan 2026 14:26:58 -0800 Subject: [PATCH 4/6] fix --- eval_protocol/adapters/fireworks_tracing.py | 22 +++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 4007f72b..3c701ab2 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -9,6 +9,7 @@ import requests from datetime import datetime import ast +import json import os from typing import Any, Dict, List, Optional, Protocol @@ -56,18 +57,23 @@ def extract_openai_response(observations: List[Dict[str, Any]]) -> Optional[Dict """ for obs in observations: if obs.get("name") == "raw_gen_ai_request" and obs.get("type") == "SPAN": - metadata = obs.get("metadata", {}) - attributes = metadata.get("attributes", {}) + metadata = obs.get("metadata") or {} + attributes = metadata.get("attributes") or {} result: Dict[str, Any] = {} for key, value in attributes.items(): - # Try to parse stringified Python literals, otherwise keep as-is + # Try to parse stringified objects (could be Python repr or JSON) if isinstance(value, str) and value.startswith(("[", "{")): try: result[key] = ast.literal_eval(value) - except Exception: - result[key] = value + except Exception as e: + logger.debug("Failed to parse %s with ast.literal_eval: %s", key, e) + try: + result[key] = json.loads(value) + except Exception as e: + logger.debug("Failed to parse %s with json.loads: %s", key, e) + result[key] = value else: result[key] = value @@ -129,7 +135,7 @@ def convert_trace_dict_to_evaluation_row( ): break # Break early if we've found all the metadata we need - observations = trace.get("observations", []) + observations = trace.get("observations") or [] # We can only extract when stored in OTEL format. openai_response = extract_openai_response(observations) if openai_response: @@ -201,7 +207,7 @@ def extract_messages_from_trace_dict( # Fallback: use the last GENERATION observation which typically contains full chat history if not messages: try: - all_observations = trace.get("observations", []) + all_observations = trace.get("observations") or [] gens = [obs for obs in all_observations if obs.get("type") == "GENERATION"] if gens: gens.sort(key=lambda x: x.get("start_time", "")) @@ -227,7 +233,7 @@ def get_final_generation_in_span_dict(trace: Dict[str, Any], span_name: str) -> The final generation dictionary, or None if not found """ # Get all observations from the trace - all_observations = trace.get("observations", []) + all_observations = trace.get("observations") or [] # Find a span with the given name that has generation children parent_span = None From ffc935137625e32c0e76ac02e3b4ccdda3549913 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 26 Jan 2026 15:01:19 -0800 Subject: [PATCH 5/6] add --- tests/remote_server/remote_server.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/remote_server/remote_server.py b/tests/remote_server/remote_server.py index 4ac4fd6c..c7655671 100644 --- a/tests/remote_server/remote_server.py +++ b/tests/remote_server/remote_server.py @@ -13,6 +13,9 @@ app = FastAPI() +# Configure logging for the remote server (required for INFO-level logs to be emitted) +logging.basicConfig(level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s") + # Attach Fireworks tracing handler to root logger fireworks_handler = FireworksTracingHttpHandler() logging.getLogger().addHandler(fireworks_handler) From b35c4bc77e18a035d19c19310d878af91b099432 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 26 Jan 2026 15:07:41 -0800 Subject: [PATCH 6/6] fix --- eval_protocol/reward_function.py | 1 - tests/remote_server/test_remote_fireworks.py | 39 +++++++++++++++++++- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/eval_protocol/reward_function.py b/eval_protocol/reward_function.py index 6bd11974..743d3c7c 100644 --- a/eval_protocol/reward_function.py +++ b/eval_protocol/reward_function.py @@ -12,7 +12,6 @@ from .models import EvaluateResult, MetricResult from .typed_interface import reward_function -logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) T = TypeVar("T", bound=Callable[..., EvaluateResult]) diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index 145e97c3..b196cb34 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -1,5 +1,6 @@ # AUTO SERVER STARTUP: Server is automatically started and stopped by the test +import logging import subprocess import socket import time @@ -19,10 +20,23 @@ ROLLOUT_IDS = set() +class StatusLogCaptureHandler(logging.Handler): + """Custom handler to capture status log messages.""" + + def __init__(self): + super().__init__() + self.status_100_messages: List[str] = [] + + def emit(self, record): + msg = record.getMessage() # Use getMessage(), not .message attribute + if "Found Fireworks log" in msg and "with status code 100" in msg: + self.status_100_messages.append(msg) + + @pytest.fixture(autouse=True) def check_rollout_coverage(monkeypatch): """ - Ensure we attempted to fetch remote traces for each rollout. + Ensure we attempted to fetch remote traces for each rollout and received status logs. This wraps the built-in default_fireworks_output_data_loader (without making it configurable) and tracks rollout_ids passed through its DataLoaderConfig. @@ -37,9 +51,32 @@ def wrapped_loader(config: DataLoaderConfig) -> DynamicDataLoader: return original_loader(config) monkeypatch.setattr(remote_rollout_processor_module, "default_fireworks_output_data_loader", wrapped_loader) + + # Add custom handler to capture status logs + status_handler = StatusLogCaptureHandler() + status_handler.setLevel(logging.INFO) + rrp_logger = logging.getLogger("eval_protocol.pytest.remote_rollout_processor") + rrp_logger.addHandler(status_handler) + # Ensure the logger level allows INFO messages through + original_level = rrp_logger.level + rrp_logger.setLevel(logging.INFO) + yield + + # Cleanup handler and restore level + rrp_logger.removeHandler(status_handler) + rrp_logger.setLevel(original_level) + + # After test completes, verify we saw status logs for all 3 rollouts assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" + # Check that we received "Found Fireworks log ... with status code 100" for each rollout + assert len(status_handler.status_100_messages) == 3, ( + f"Expected 3 'Found Fireworks log ... with status code 100' messages, but found {len(status_handler.status_100_messages)}. " + f"This means the status logs from the remote server were not received. " + f"Messages captured: {status_handler.status_100_messages}" + ) + def find_available_port() -> int: """Find an available port on localhost"""