-
Notifications
You must be signed in to change notification settings - Fork 12
perf: Add server-side filtering for ep logs to improve performance #408
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,6 +45,9 @@ def enable_debug_mode(): | |
| print("Debug mode enabled for all relevant loggers") | ||
|
|
||
|
|
||
| DEFAULT_MAX_LOGS_LIMIT = 1000 # Default limit for initial log load to prevent slowdowns | ||
|
|
||
|
|
||
| class WebSocketManager: | ||
| """Manages WebSocket connections and broadcasts messages.""" | ||
|
|
||
|
|
@@ -53,17 +56,42 @@ def __init__(self): | |
| self._broadcast_queue: Queue = Queue() | ||
| self._broadcast_task: Optional[asyncio.Task] = None | ||
| self._lock = threading.Lock() | ||
| # Track which invocation_ids each connection is subscribed to (None = all) | ||
| self._connection_filters: Dict[WebSocket, Optional[List[str]]] = {} | ||
|
|
||
| async def connect(self, websocket: WebSocket): | ||
| async def connect( | ||
| self, | ||
| websocket: WebSocket, | ||
| invocation_ids: Optional[List[str]] = None, | ||
| limit: Optional[int] = None, | ||
| ): | ||
| """ | ||
| Connect a WebSocket client and send initial logs. | ||
|
|
||
| Args: | ||
| websocket: The WebSocket connection | ||
| invocation_ids: Optional list of invocation_ids to filter logs | ||
| limit: Maximum number of logs to send initially (defaults to DEFAULT_MAX_LOGS_LIMIT) | ||
| """ | ||
| logger.debug("[WEBSOCKET_CONNECT] New websocket connection attempt") | ||
| await websocket.accept() | ||
| with self._lock: | ||
| self.active_connections.append(websocket) | ||
| self._connection_filters[websocket] = invocation_ids | ||
| connection_count = len(self.active_connections) | ||
| logger.info(f"[WEBSOCKET_CONNECT] WebSocket connected. Total connections: {connection_count}") | ||
| logger.info( | ||
| f"[WEBSOCKET_CONNECT] WebSocket connected. Total connections: {connection_count}, " | ||
| f"invocation_ids filter: {invocation_ids}, limit: {limit}" | ||
| ) | ||
|
|
||
| # Use provided limit or default | ||
| effective_limit = limit if limit is not None else DEFAULT_MAX_LOGS_LIMIT | ||
|
|
||
| logger.debug("[WEBSOCKET_CONNECT] Reading logs for initialization") | ||
| logs = default_logger.read() | ||
| logger.debug( | ||
| f"[WEBSOCKET_CONNECT] Reading logs for initialization with " | ||
| f"invocation_ids={invocation_ids}, limit={effective_limit}" | ||
| ) | ||
| logs = default_logger.read(invocation_ids=invocation_ids, limit=effective_limit) | ||
| logger.debug(f"[WEBSOCKET_CONNECT] Found {len(logs)} logs to send") | ||
|
|
||
| data = { | ||
|
|
@@ -82,16 +110,25 @@ def disconnect(self, websocket: WebSocket): | |
| logger.debug("[WEBSOCKET_DISCONNECT] Removed websocket from active connections") | ||
| else: | ||
| logger.debug("[WEBSOCKET_DISCONNECT] Websocket was not in active connections") | ||
| # Clean up connection filter | ||
| if websocket in self._connection_filters: | ||
| del self._connection_filters[websocket] | ||
| connection_count = len(self.active_connections) | ||
| logger.info(f"[WEBSOCKET_DISCONNECT] WebSocket disconnected. Total connections: {connection_count}") | ||
|
|
||
| def broadcast_row_upserted(self, row: "EvaluationRow"): | ||
| """Broadcast a row-upsert event to all connected clients. | ||
|
|
||
| Safe no-op if server loop is not running or there are no connections. | ||
| Messages are only sent to connections whose invocation_id filter matches the row, | ||
| or to connections with no filter (subscribed to all). | ||
| """ | ||
| rollout_id = row.execution_metadata.rollout_id if row.execution_metadata else "unknown" | ||
| logger.debug(f"[WEBSOCKET_BROADCAST] Starting broadcast for rollout_id: {rollout_id}") | ||
| row_invocation_id = row.execution_metadata.invocation_id if row.execution_metadata else None | ||
| logger.debug( | ||
| f"[WEBSOCKET_BROADCAST] Starting broadcast for rollout_id: {rollout_id}, " | ||
| f"invocation_id: {row_invocation_id}" | ||
| ) | ||
|
|
||
| with self._lock: | ||
| active_connections_count = len(self.active_connections) | ||
|
|
@@ -105,9 +142,9 @@ def broadcast_row_upserted(self, row: "EvaluationRow"): | |
| f"[WEBSOCKET_BROADCAST] Successfully serialized message (length: {len(json_message)}) for rollout_id: {rollout_id}" | ||
| ) | ||
|
|
||
| # Queue the message for broadcasting in the main event loop | ||
| # Queue the message for broadcasting in the main event loop, along with invocation_id for filtering | ||
| logger.debug(f"[WEBSOCKET_BROADCAST] Queuing message for broadcast for rollout_id: {rollout_id}") | ||
| self._broadcast_queue.put(json_message) | ||
| self._broadcast_queue.put((json_message, row_invocation_id)) | ||
| logger.debug(f"[WEBSOCKET_BROADCAST] Successfully queued message for rollout_id: {rollout_id}") | ||
| except Exception as e: | ||
| logger.error( | ||
|
|
@@ -121,15 +158,25 @@ async def _start_broadcast_loop(self): | |
| try: | ||
| # Wait for a message to be queued | ||
| logger.debug("[WEBSOCKET_BROADCAST_LOOP] Waiting for message from queue") | ||
| message_data = await asyncio.get_event_loop().run_in_executor(None, self._broadcast_queue.get) | ||
| queue_item = await asyncio.get_event_loop().run_in_executor(None, self._broadcast_queue.get) | ||
|
|
||
| # Queue item is a tuple of (json_message, row_invocation_id) | ||
| if isinstance(queue_item, tuple): | ||
| json_message, row_invocation_id = queue_item | ||
| else: | ||
| # Backward compatibility: if it's just a string, send to all | ||
| json_message = str(queue_item) | ||
| row_invocation_id = None | ||
|
|
||
| logger.debug( | ||
| f"[WEBSOCKET_BROADCAST_LOOP] Retrieved message from queue (length: {len(str(message_data))})" | ||
| f"[WEBSOCKET_BROADCAST_LOOP] Retrieved message from queue (length: {len(json_message)}), " | ||
| f"invocation_id: {row_invocation_id}" | ||
| ) | ||
|
|
||
| # Regular string message for all connections | ||
| logger.debug("[WEBSOCKET_BROADCAST_LOOP] Sending message to all connections") | ||
| await self._send_text_to_all_connections(str(message_data)) | ||
| logger.debug("[WEBSOCKET_BROADCAST_LOOP] Successfully sent message to all connections") | ||
| # Send message to connections that match the filter | ||
| logger.debug("[WEBSOCKET_BROADCAST_LOOP] Sending message to filtered connections") | ||
| await self._send_text_to_filtered_connections(json_message, row_invocation_id) | ||
| logger.debug("[WEBSOCKET_BROADCAST_LOOP] Successfully sent message to connections") | ||
|
|
||
| except Exception as e: | ||
| logger.error(f"[WEBSOCKET_BROADCAST_LOOP] Error in broadcast loop: {e}") | ||
|
|
@@ -138,28 +185,54 @@ async def _start_broadcast_loop(self): | |
| logger.info("[WEBSOCKET_BROADCAST_LOOP] Broadcast loop cancelled") | ||
| break | ||
|
|
||
| async def _send_text_to_all_connections(self, text: str): | ||
| async def _send_text_to_filtered_connections(self, text: str, row_invocation_id: Optional[str] = None): | ||
| """ | ||
| Send text to connections that match the invocation_id filter. | ||
|
|
||
| Args: | ||
| text: The message to send | ||
| row_invocation_id: The invocation_id of the row being sent. | ||
| Connections with no filter (None) receive all messages. | ||
| Connections with a filter only receive messages where row_invocation_id is in their filter. | ||
| """ | ||
| with self._lock: | ||
| connections = list(self.active_connections) | ||
| connection_filters = dict(self._connection_filters) | ||
|
|
||
| # Filter connections based on their subscribed invocation_ids | ||
| eligible_connections = [] | ||
| for conn in connections: | ||
| conn_filter = connection_filters.get(conn) | ||
| if conn_filter is None: | ||
| # No filter means subscribed to all | ||
| eligible_connections.append(conn) | ||
| elif row_invocation_id is not None and row_invocation_id in conn_filter: | ||
| # Row's invocation_id matches connection's filter | ||
| eligible_connections.append(conn) | ||
|
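There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. **Empty `invocation_ids` filter causes inconsistent behavior** (Low Severity). An empty `invocation_ids` list (for example, from `?invocation_ids=,`) is not `None`, so it is treated as an active filter that matches no rows, and the connection receives no broadcast updates. Additional Locations (1) |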
||
| # else: skip this connection | ||
|
|
||
| logger.debug( | ||
| f"[WEBSOCKET_SEND] Attempting to send to {len(eligible_connections)} of {len(connections)} connections " | ||
| f"(filtered by invocation_id: {row_invocation_id})" | ||
| ) | ||
|
|
||
| logger.debug(f"[WEBSOCKET_SEND] Attempting to send to {len(connections)} connections") | ||
|
|
||
| if not connections: | ||
| logger.debug("[WEBSOCKET_SEND] No connections available, skipping send") | ||
| if not eligible_connections: | ||
| logger.debug("[WEBSOCKET_SEND] No eligible connections, skipping send") | ||
| return | ||
|
|
||
| tasks = [] | ||
| failed_connections = [] | ||
| task_connections = [] # Track which connection each task corresponds to | ||
|
|
||
| for i, connection in enumerate(connections): | ||
| for i, connection in enumerate(eligible_connections): | ||
| try: | ||
| logger.debug(f"[WEBSOCKET_SEND] Preparing to send to connection {i}") | ||
| tasks.append(connection.send_text(text)) | ||
| task_connections.append(connection) | ||
|
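There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. **Failed connection preparation skips cleanup** (Low Severity). In `_send_text_to_filtered_connections`, connections that fail while a send is being prepared are appended to `failed_connections`, but the list appears to be reassigned to `[]` before the parallel sends run, so those connections would never be removed from the active set. Additional Locations (1) |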
||
| except Exception as e: | ||
| logger.error(f"[WEBSOCKET_SEND] Failed to prepare send to WebSocket {i}: {e}") | ||
| failed_connections.append(connection) | ||
|
|
||
| # Execute all sends in parallel | ||
| failed_connections = [] | ||
| if tasks: | ||
| logger.debug(f"[WEBSOCKET_SEND] Executing {len(tasks)} parallel sends") | ||
| results = await asyncio.gather(*tasks, return_exceptions=True) | ||
|
|
@@ -169,7 +242,7 @@ async def _send_text_to_all_connections(self, text: str): | |
| for i, result in enumerate(results): | ||
| if isinstance(result, Exception): | ||
| logger.error(f"[WEBSOCKET_SEND] Failed to send text to WebSocket {i}: {result}") | ||
| failed_connections.append(connections[i]) | ||
| failed_connections.append(task_connections[i]) | ||
| else: | ||
| logger.debug(f"[WEBSOCKET_SEND] Successfully sent to connection {i}") | ||
|
|
||
|
|
@@ -180,6 +253,8 @@ async def _send_text_to_all_connections(self, text: str): | |
| for connection in failed_connections: | ||
| try: | ||
| self.active_connections.remove(connection) | ||
| if connection in self._connection_filters: | ||
| del self._connection_filters[connection] | ||
| except ValueError: | ||
| pass | ||
|
|
||
|
|
@@ -393,7 +468,27 @@ def _setup_websocket_routes(self): | |
|
|
||
| @self.app.websocket("/ws") | ||
| async def websocket_endpoint(websocket: WebSocket): | ||
| await self.websocket_manager.connect(websocket) | ||
| # Parse query parameters from WebSocket connection URL | ||
| # invocation_ids: comma-separated list of invocation IDs to filter | ||
| # limit: maximum number of initial logs to load | ||
| query_params = websocket.query_params | ||
| invocation_ids_param = query_params.get("invocation_ids") | ||
| limit_param = query_params.get("limit") | ||
|
|
||
| invocation_ids: Optional[List[str]] = None | ||
| if invocation_ids_param: | ||
| invocation_ids = [id.strip() for id in invocation_ids_param.split(",") if id.strip()] | ||
| logger.info(f"[WEBSOCKET] Client filtering by invocation_ids: {invocation_ids}") | ||
|
|
||
| limit: Optional[int] = None | ||
| if limit_param: | ||
| try: | ||
| limit = int(limit_param) | ||
| logger.info(f"[WEBSOCKET] Client requested limit: {limit}") | ||
| except ValueError: | ||
| logger.warning(f"[WEBSOCKET] Invalid limit parameter: {limit_param}") | ||
|
|
||
| await self.websocket_manager.connect(websocket, invocation_ids=invocation_ids, limit=limit) | ||
| try: | ||
| while True: | ||
| # Keep connection alive (for evaluation row updates) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The abstract `DatasetLogger.read` signature now includes `rollout_id`, `invocation_ids`, and `limit`, but `LocalFSDatasetLoggerAdapter` still only defines `read(self, row_id=...)`. Any caller that treats a `LocalFSDatasetLoggerAdapter` as a `DatasetLogger` and passes the new keywords (for example, `read(invocation_ids=[...])` or `read(rollout_id=...)`) will now hit a `TypeError` for unexpected keyword arguments. To preserve compatibility with existing adapters, they should be updated to accept the new parameters (even if they ignore them).
Useful? React with 👍 / 👎.