eval_protocol/dataset_logger/__init__.py (12 changes: 9 additions & 3 deletions)
@@ -1,4 +1,5 @@
 import os
+from typing import List, Optional

 from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
 from eval_protocol.dataset_logger.sqlite_dataset_logger_adapter import SqliteDatasetLoggerAdapter
@@ -14,7 +15,7 @@ class _NoOpLogger(DatasetLogger):
         def log(self, row):
             return None

-        def read(self, rollout_id=None):
+        def read(self, rollout_id=None, invocation_ids=None, limit=None):
             return []

     return _NoOpLogger()
@@ -33,8 +34,13 @@ def _get_logger(self):
     def log(self, row):
         return self._get_logger().log(row)

-    def read(self, rollout_id=None):
-        return self._get_logger().read(rollout_id)
+    def read(
+        self,
+        rollout_id: Optional[str] = None,
+        invocation_ids: Optional[List[str]] = None,
+        limit: Optional[int] = None,
+    ):
+        return self._get_logger().read(rollout_id=rollout_id, invocation_ids=invocation_ids, limit=limit)


 default_logger: DatasetLogger = _LazyLogger()
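
For orientation, a minimal usage sketch of the widened read() surface; the ID strings below are hypothetical placeholders, not values from this PR:

from eval_protocol.dataset_logger import default_logger

# The 50 most recent rows produced by two specific invocations
# (both IDs are made-up examples).
rows = default_logger.read(invocation_ids=["inv-aaa", "inv-bbb"], limit=50)

# The existing rollout_id filter still works, now forwarded by keyword.
rows_for_rollout = default_logger.read(rollout_id="rollout-123")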
eval_protocol/dataset_logger/dataset_logger.py (13 changes: 10 additions & 3 deletions)
@@ -24,12 +24,19 @@ def log(self, row: "EvaluationRow") -> None:
         pass

     @abstractmethod
-    def read(self, row_id: Optional[str] = None) -> List["EvaluationRow"]:
+    def read(
+        self,
+        rollout_id: Optional[str] = None,
+        invocation_ids: Optional[List[str]] = None,
+        limit: Optional[int] = None,
Review comment on lines 26 to +31 (P2): Keep DatasetLogger implementations aligned with new read() args

The abstract DatasetLogger.read signature now includes rollout_id, invocation_ids, and limit, but LocalFSDatasetLoggerAdapter still only defines read(self, row_id=...). Any caller that treats a LocalFSDatasetLoggerAdapter as a DatasetLogger and passes the new keywords (for example, read(invocation_ids=[...]) or read(rollout_id=...)) will now hit a TypeError for unexpected keyword arguments. To preserve compatibility with existing adapters, they should be updated to accept the new parameters (even if they ignore them); a sketch of such an update follows this file's diff.

+    ) -> List["EvaluationRow"]:
         """
-        Retrieve EvaluationRow logs.
+        Retrieve EvaluationRow logs with optional filtering.

         Args:
-            row_id (Optional[str]): If provided, filter logs by this row_id.
+            rollout_id (Optional[str]): If provided, filter logs by this rollout_id.
+            invocation_ids (Optional[List[str]]): If provided, filter logs by these invocation_ids.
+            limit (Optional[int]): If provided, limit the number of rows returned (most recent first).

         Returns:
             List[EvaluationRow]: List of retrieved evaluation rows.
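Following up on the review comment above, here is a sketch of how a file-system adapter could accept the new keywords. The class name and the old row_id-based read come from the comment; _read_all_rows and the most-recent-first ordering are hypothetical stand-ins for whatever the adapter actually does:

from typing import List, Optional

from eval_protocol.dataset_logger.dataset_logger import DatasetLogger

class LocalFSDatasetLoggerAdapter(DatasetLogger):
    def read(
        self,
        rollout_id: Optional[str] = None,
        invocation_ids: Optional[List[str]] = None,
        limit: Optional[int] = None,
    ) -> List["EvaluationRow"]:
        rows = self._read_all_rows()  # hypothetical existing loader
        if rollout_id is not None:
            rows = [r for r in rows if r.execution_metadata and r.execution_metadata.rollout_id == rollout_id]
        if invocation_ids:
            wanted = set(invocation_ids)
            rows = [r for r in rows if r.execution_metadata and r.execution_metadata.invocation_id in wanted]
        if limit is not None:
            rows = rows[:limit]  # assumes rows are already most-recent-first
        return rows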
eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py (9 changes: 7 additions & 2 deletions)
@@ -38,8 +38,13 @@ def log(self, row: "EvaluationRow") -> None:
logger.error(f"[EVENT_BUS_EMIT] Failed to emit row_upserted event for rollout_id {rollout_id}: {e}")
pass

def read(self, rollout_id: Optional[str] = None) -> List["EvaluationRow"]:
def read(
self,
rollout_id: Optional[str] = None,
invocation_ids: Optional[List[str]] = None,
limit: Optional[int] = None,
) -> List["EvaluationRow"]:
from eval_protocol.models import EvaluationRow

results = self._store.read_rows(rollout_id=rollout_id)
results = self._store.read_rows(rollout_id=rollout_id, invocation_ids=invocation_ids, limit=limit)
return [EvaluationRow(**data) for data in results]
eval_protocol/dataset_logger/sqlite_evaluation_row_store.py (57 changes: 50 additions & 7 deletions)
@@ -1,7 +1,7 @@
 import os
 from typing import List, Optional

-from peewee import CharField, Model, SqliteDatabase
+from peewee import CharField, Model, SqliteDatabase, fn, SQL
 from playhouse.sqlite_ext import JSONField

 from eval_protocol.event_bus.sqlite_event_bus_database import (
@@ -67,12 +67,55 @@ def _do_upsert(self, rollout_id: str, data: dict) -> None:
         else:
             self._EvaluationRow.create(rollout_id=rollout_id, data=data)

-    def read_rows(self, rollout_id: Optional[str] = None) -> List[dict]:
-        if rollout_id is None:
-            query = self._EvaluationRow.select().dicts()
-        else:
-            query = self._EvaluationRow.select().dicts().where(self._EvaluationRow.rollout_id == rollout_id)
-        results = list(query)
+    def read_rows(
+        self,
+        rollout_id: Optional[str] = None,
+        invocation_ids: Optional[List[str]] = None,
+        limit: Optional[int] = None,
+    ) -> List[dict]:
+        """
+        Read evaluation rows from the database with optional filtering.
+
+        Args:
+            rollout_id: Filter by a specific rollout_id (exact match)
+            invocation_ids: Filter by a list of invocation_ids (rows matching any)
+            limit: Maximum number of rows to return (most recent first)
+
+        Returns:
+            List of evaluation row data dictionaries
+        """
+        query = self._EvaluationRow.select()
+
+        if rollout_id is not None:
+            query = query.where(self._EvaluationRow.rollout_id == rollout_id)
+
+        # Apply invocation_ids filter using JSON extraction
+        # Note: This filters rows where data->'execution_metadata'->>'invocation_id' matches any of the provided IDs
+        if invocation_ids is not None and len(invocation_ids) > 0:
+            # Build a condition that matches any of the invocation_ids
+            # Using SQLite JSON extraction: json_extract(data, '$.execution_metadata.invocation_id')
+            invocation_conditions = []
+            for inv_id in invocation_ids:
+                invocation_conditions.append(
+                    fn.json_extract(self._EvaluationRow.data, "$.execution_metadata.invocation_id") == inv_id
+                )
+            # Combine with OR
+            if len(invocation_conditions) == 1:
+                query = query.where(invocation_conditions[0])
+            else:
+                from functools import reduce
+                from operator import or_
+
+                combined_condition = reduce(or_, invocation_conditions)
+                query = query.where(combined_condition)
+
+        # Order by rowid descending to get most recent rows first
+        query = query.order_by(SQL("rowid DESC"))
+
+        if limit is not None:
+            query = query.limit(limit)
+
+        results = list(query.dicts())
         return [result["data"] for result in results]

     def delete_row(self, rollout_id: str) -> int:
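A side note on the OR-reduce above: because every branch compares the same json_extract expression against a single ID, the whole construction collapses to one SQL IN clause. A sketch of the equivalent peewee formulation, assuming the same model and column names as in this diff:

# Equivalent to the reduce(or_, ...) construction above: one IN clause over
# the extracted JSON value.
if invocation_ids:
    query = query.where(
        fn.json_extract(self._EvaluationRow.data, "$.execution_metadata.invocation_id").in_(invocation_ids)
    )

Either form returns the same rows; the IN version simply avoids building the condition list by hand.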
eval_protocol/utils/logs_server.py (141 changes: 118 additions & 23 deletions)
@@ -45,6 +45,9 @@ def enable_debug_mode():
print("Debug mode enabled for all relevant loggers")


DEFAULT_MAX_LOGS_LIMIT = 1000 # Default limit for initial log load to prevent slowdowns


class WebSocketManager:
"""Manages WebSocket connections and broadcasts messages."""

@@ -53,17 +56,42 @@ def __init__(self):
         self._broadcast_queue: Queue = Queue()
         self._broadcast_task: Optional[asyncio.Task] = None
         self._lock = threading.Lock()
+        # Track which invocation_ids each connection is subscribed to (None = all)
+        self._connection_filters: Dict[WebSocket, Optional[List[str]]] = {}

-    async def connect(self, websocket: WebSocket):
+    async def connect(
+        self,
+        websocket: WebSocket,
+        invocation_ids: Optional[List[str]] = None,
+        limit: Optional[int] = None,
+    ):
+        """
+        Connect a WebSocket client and send initial logs.
+
+        Args:
+            websocket: The WebSocket connection
+            invocation_ids: Optional list of invocation_ids to filter logs
+            limit: Maximum number of logs to send initially (defaults to DEFAULT_MAX_LOGS_LIMIT)
+        """
         logger.debug("[WEBSOCKET_CONNECT] New websocket connection attempt")
         await websocket.accept()
         with self._lock:
             self.active_connections.append(websocket)
+            self._connection_filters[websocket] = invocation_ids
             connection_count = len(self.active_connections)
-        logger.info(f"[WEBSOCKET_CONNECT] WebSocket connected. Total connections: {connection_count}")
+        logger.info(
+            f"[WEBSOCKET_CONNECT] WebSocket connected. Total connections: {connection_count}, "
+            f"invocation_ids filter: {invocation_ids}, limit: {limit}"
+        )

+        # Use provided limit or default
+        effective_limit = limit if limit is not None else DEFAULT_MAX_LOGS_LIMIT

-        logger.debug("[WEBSOCKET_CONNECT] Reading logs for initialization")
-        logs = default_logger.read()
+        logger.debug(
+            f"[WEBSOCKET_CONNECT] Reading logs for initialization with "
+            f"invocation_ids={invocation_ids}, limit={effective_limit}"
+        )
+        logs = default_logger.read(invocation_ids=invocation_ids, limit=effective_limit)
         logger.debug(f"[WEBSOCKET_CONNECT] Found {len(logs)} logs to send")

         data = {
@@ -82,16 +110,25 @@ def disconnect(self, websocket: WebSocket):
logger.debug("[WEBSOCKET_DISCONNECT] Removed websocket from active connections")
else:
logger.debug("[WEBSOCKET_DISCONNECT] Websocket was not in active connections")
# Clean up connection filter
if websocket in self._connection_filters:
del self._connection_filters[websocket]
connection_count = len(self.active_connections)
logger.info(f"[WEBSOCKET_DISCONNECT] WebSocket disconnected. Total connections: {connection_count}")

def broadcast_row_upserted(self, row: "EvaluationRow"):
"""Broadcast a row-upsert event to all connected clients.

Safe no-op if server loop is not running or there are no connections.
Messages are only sent to connections whose invocation_id filter matches the row,
or to connections with no filter (subscribed to all).
"""
rollout_id = row.execution_metadata.rollout_id if row.execution_metadata else "unknown"
logger.debug(f"[WEBSOCKET_BROADCAST] Starting broadcast for rollout_id: {rollout_id}")
row_invocation_id = row.execution_metadata.invocation_id if row.execution_metadata else None
logger.debug(
f"[WEBSOCKET_BROADCAST] Starting broadcast for rollout_id: {rollout_id}, "
f"invocation_id: {row_invocation_id}"
)

with self._lock:
active_connections_count = len(self.active_connections)
@@ -105,9 +142,9 @@ def broadcast_row_upserted(self, row: "EvaluationRow"):
f"[WEBSOCKET_BROADCAST] Successfully serialized message (length: {len(json_message)}) for rollout_id: {rollout_id}"
)

# Queue the message for broadcasting in the main event loop
# Queue the message for broadcasting in the main event loop, along with invocation_id for filtering
logger.debug(f"[WEBSOCKET_BROADCAST] Queuing message for broadcast for rollout_id: {rollout_id}")
self._broadcast_queue.put(json_message)
self._broadcast_queue.put((json_message, row_invocation_id))
logger.debug(f"[WEBSOCKET_BROADCAST] Successfully queued message for rollout_id: {rollout_id}")
except Exception as e:
logger.error(
@@ -121,15 +158,25 @@ async def _start_broadcast_loop(self):
             try:
                 # Wait for a message to be queued
                 logger.debug("[WEBSOCKET_BROADCAST_LOOP] Waiting for message from queue")
-                message_data = await asyncio.get_event_loop().run_in_executor(None, self._broadcast_queue.get)
+                queue_item = await asyncio.get_event_loop().run_in_executor(None, self._broadcast_queue.get)

+                # Queue item is a tuple of (json_message, row_invocation_id)
+                if isinstance(queue_item, tuple):
+                    json_message, row_invocation_id = queue_item
+                else:
+                    # Backward compatibility: if it's just a string, send to all
+                    json_message = str(queue_item)
+                    row_invocation_id = None

                 logger.debug(
-                    f"[WEBSOCKET_BROADCAST_LOOP] Retrieved message from queue (length: {len(str(message_data))})"
+                    f"[WEBSOCKET_BROADCAST_LOOP] Retrieved message from queue (length: {len(json_message)}), "
+                    f"invocation_id: {row_invocation_id}"
                 )

-                # Regular string message for all connections
-                logger.debug("[WEBSOCKET_BROADCAST_LOOP] Sending message to all connections")
-                await self._send_text_to_all_connections(str(message_data))
-                logger.debug("[WEBSOCKET_BROADCAST_LOOP] Successfully sent message to all connections")
+                # Send message to connections that match the filter
+                logger.debug("[WEBSOCKET_BROADCAST_LOOP] Sending message to filtered connections")
+                await self._send_text_to_filtered_connections(json_message, row_invocation_id)
+                logger.debug("[WEBSOCKET_BROADCAST_LOOP] Successfully sent message to connections")

             except Exception as e:
                 logger.error(f"[WEBSOCKET_BROADCAST_LOOP] Error in broadcast loop: {e}")
@@ -138,28 +185,54 @@ async def _send_text_to_all_connections(self, text: str):
logger.info("[WEBSOCKET_BROADCAST_LOOP] Broadcast loop cancelled")
break

async def _send_text_to_all_connections(self, text: str):
async def _send_text_to_filtered_connections(self, text: str, row_invocation_id: Optional[str] = None):
"""
Send text to connections that match the invocation_id filter.

Args:
text: The message to send
row_invocation_id: The invocation_id of the row being sent.
Connections with no filter (None) receive all messages.
Connections with a filter only receive messages where row_invocation_id is in their filter.
"""
with self._lock:
connections = list(self.active_connections)
connection_filters = dict(self._connection_filters)

# Filter connections based on their subscribed invocation_ids
eligible_connections = []
for conn in connections:
conn_filter = connection_filters.get(conn)
if conn_filter is None:
# No filter means subscribed to all
eligible_connections.append(conn)
elif row_invocation_id is not None and row_invocation_id in conn_filter:
# Row's invocation_id matches connection's filter
eligible_connections.append(conn)
Review comment (Low Severity): Empty invocation_ids filter causes inconsistent behavior

An empty invocation_ids list is handled inconsistently between the initial load and real-time updates. In read_rows(), the condition len(invocation_ids) > 0 treats an empty list as "no filter" (returns all rows). But in _send_text_to_filtered_connections(), the check conn_filter is None is False for [], and row_invocation_id in [] is always False, so the connection receives no updates. A WebSocket connection opened with ?invocation_ids=, therefore receives all initial logs but never receives real-time broadcasts; a normalization sketch follows this comment.
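One possible fix, not part of this PR, is to normalize an empty filter to None where the query string is parsed, so both code paths agree that an empty list means "no filter":

# In websocket_endpoint: "or None" collapses an empty parse result (e.g. from
# "?invocation_ids=,") to None, which both read_rows() and the broadcast
# filter already treat as "subscribed to all".
if invocation_ids_param is not None:
    invocation_ids = [i.strip() for i in invocation_ids_param.split(",") if i.strip()] or None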

+            # else: skip this connection

+        logger.debug(
+            f"[WEBSOCKET_SEND] Attempting to send to {len(eligible_connections)} of {len(connections)} connections "
+            f"(filtered by invocation_id: {row_invocation_id})"
+        )

-        logger.debug(f"[WEBSOCKET_SEND] Attempting to send to {len(connections)} connections")

-        if not connections:
-            logger.debug("[WEBSOCKET_SEND] No connections available, skipping send")
+        if not eligible_connections:
+            logger.debug("[WEBSOCKET_SEND] No eligible connections, skipping send")
             return

         tasks = []
-        failed_connections = []
+        task_connections = []  # Track which connection each task corresponds to

-        for i, connection in enumerate(connections):
+        for i, connection in enumerate(eligible_connections):
             try:
                 logger.debug(f"[WEBSOCKET_SEND] Preparing to send to connection {i}")
                 tasks.append(connection.send_text(text))
+                task_connections.append(connection)
Review comment (Low Severity): Failed connection preparation skips cleanup

In _send_text_to_filtered_connections, when an exception occurs during coroutine preparation (lines 231-232), the connection is no longer added to failed_connections for cleanup. The original code added failing connections to failed_connections in the except block, but this was removed during refactoring. Now failed_connections is initialized after the loop (line 235), making it impossible to track preparation failures. Connections that fail during preparation will remain in active_connections and _connection_filters until they fail during an actual send attempt; a reordering sketch follows the cleanup hunk below.

             except Exception as e:
                 logger.error(f"[WEBSOCKET_SEND] Failed to prepare send to WebSocket {i}: {e}")
-                failed_connections.append(connection)

         # Execute all sends in parallel
+        failed_connections = []
         if tasks:
             logger.debug(f"[WEBSOCKET_SEND] Executing {len(tasks)} parallel sends")
             results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -169,7 +242,7 @@ async def _send_text_to_all_connections(self, text: str):
             for i, result in enumerate(results):
                 if isinstance(result, Exception):
                     logger.error(f"[WEBSOCKET_SEND] Failed to send text to WebSocket {i}: {result}")
-                    failed_connections.append(connections[i])
+                    failed_connections.append(task_connections[i])
                 else:
                     logger.debug(f"[WEBSOCKET_SEND] Successfully sent to connection {i}")
@@ -180,6 +253,8 @@ async def _send_text_to_all_connections(self, text: str):
         for connection in failed_connections:
             try:
                 self.active_connections.remove(connection)
+                if connection in self._connection_filters:
+                    del self._connection_filters[connection]
             except ValueError:
                 pass

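Per the "Failed connection preparation skips cleanup" comment above, a minimal reordering sketch (not part of this PR) that would restore cleanup tracking for preparation failures: initialize failed_connections before the prep loop rather than after it.

tasks = []
task_connections = []  # which connection each task corresponds to
failed_connections = []  # before the loop, so prep failures are tracked again

for i, connection in enumerate(eligible_connections):
    try:
        tasks.append(connection.send_text(text))
        task_connections.append(connection)
    except Exception as e:
        logger.error(f"[WEBSOCKET_SEND] Failed to prepare send to WebSocket {i}: {e}")
        failed_connections.append(connection)  # restores the pre-refactor behavior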
@@ -393,7 +468,27 @@ def _setup_websocket_routes(self):

@self.app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await self.websocket_manager.connect(websocket)
# Parse query parameters from WebSocket connection URL
# invocation_ids: comma-separated list of invocation IDs to filter
# limit: maximum number of initial logs to load
query_params = websocket.query_params
invocation_ids_param = query_params.get("invocation_ids")
limit_param = query_params.get("limit")

invocation_ids: Optional[List[str]] = None
if invocation_ids_param:
invocation_ids = [id.strip() for id in invocation_ids_param.split(",") if id.strip()]
logger.info(f"[WEBSOCKET] Client filtering by invocation_ids: {invocation_ids}")

limit: Optional[int] = None
if limit_param:
try:
limit = int(limit_param)
logger.info(f"[WEBSOCKET] Client requested limit: {limit}")
except ValueError:
logger.warning(f"[WEBSOCKET] Invalid limit parameter: {limit_param}")

await self.websocket_manager.connect(websocket, invocation_ids=invocation_ids, limit=limit)
try:
while True:
# Keep connection alive (for evaluation row updates)
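To exercise the new query parameters end to end, here is a client-side sketch using the third-party websockets library. The host, port, and IDs are assumptions, and the payload keys are illustrative since the initial message shape is truncated in this diff:

import asyncio
import json

import websockets  # third-party client: pip install websockets

async def tail_logs():
    # Filter to two invocations and cap the initial load at 100 rows.
    url = "ws://localhost:8000/ws?invocation_ids=inv-aaa,inv-bbb&limit=100"
    async with websockets.connect(url) as ws:
        initial = json.loads(await ws.recv())  # initial filtered batch on connect
        print("initial message keys:", list(initial))
        async for message in ws:  # then only broadcasts matching the filter
            print(json.loads(message))

asyncio.run(tail_logs())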