diff --git a/eval_protocol/dataset_logger/__init__.py b/eval_protocol/dataset_logger/__init__.py index b3fc1cb2..c3f089e5 100644 --- a/eval_protocol/dataset_logger/__init__.py +++ b/eval_protocol/dataset_logger/__init__.py @@ -1,4 +1,5 @@ import os +from typing import List, Optional from eval_protocol.dataset_logger.dataset_logger import DatasetLogger from eval_protocol.dataset_logger.sqlite_dataset_logger_adapter import SqliteDatasetLoggerAdapter @@ -14,7 +15,7 @@ class _NoOpLogger(DatasetLogger): def log(self, row): return None - def read(self, rollout_id=None): + def read(self, rollout_id=None, invocation_ids=None, limit=None): return [] return _NoOpLogger() @@ -33,8 +34,13 @@ def _get_logger(self): def log(self, row): return self._get_logger().log(row) - def read(self, rollout_id=None): - return self._get_logger().read(rollout_id) + def read( + self, + rollout_id: Optional[str] = None, + invocation_ids: Optional[List[str]] = None, + limit: Optional[int] = None, + ): + return self._get_logger().read(rollout_id=rollout_id, invocation_ids=invocation_ids, limit=limit) default_logger: DatasetLogger = _LazyLogger() diff --git a/eval_protocol/dataset_logger/dataset_logger.py b/eval_protocol/dataset_logger/dataset_logger.py index ac735b10..bfbc69b4 100644 --- a/eval_protocol/dataset_logger/dataset_logger.py +++ b/eval_protocol/dataset_logger/dataset_logger.py @@ -24,12 +24,19 @@ def log(self, row: "EvaluationRow") -> None: pass @abstractmethod - def read(self, row_id: Optional[str] = None) -> List["EvaluationRow"]: + def read( + self, + rollout_id: Optional[str] = None, + invocation_ids: Optional[List[str]] = None, + limit: Optional[int] = None, + ) -> List["EvaluationRow"]: """ - Retrieve EvaluationRow logs. + Retrieve EvaluationRow logs with optional filtering. Args: - row_id (Optional[str]): If provided, filter logs by this row_id. + rollout_id (Optional[str]): If provided, filter logs by this rollout_id. + invocation_ids (Optional[List[str]]): If provided, filter logs by these invocation_ids. + limit (Optional[int]): If provided, limit the number of rows returned (most recent first). Returns: List[EvaluationRow]: List of retrieved evaluation rows. diff --git a/eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py b/eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py index 5f360bfc..e5943ddf 100644 --- a/eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py +++ b/eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py @@ -38,8 +38,13 @@ def log(self, row: "EvaluationRow") -> None: logger.error(f"[EVENT_BUS_EMIT] Failed to emit row_upserted event for rollout_id {rollout_id}: {e}") pass - def read(self, rollout_id: Optional[str] = None) -> List["EvaluationRow"]: + def read( + self, + rollout_id: Optional[str] = None, + invocation_ids: Optional[List[str]] = None, + limit: Optional[int] = None, + ) -> List["EvaluationRow"]: from eval_protocol.models import EvaluationRow - results = self._store.read_rows(rollout_id=rollout_id) + results = self._store.read_rows(rollout_id=rollout_id, invocation_ids=invocation_ids, limit=limit) return [EvaluationRow(**data) for data in results] diff --git a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py index f6a81e1e..ce4a8419 100644 --- a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py +++ b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py @@ -1,7 +1,7 @@ import os from typing import List, Optional -from peewee import CharField, Model, SqliteDatabase +from peewee import CharField, Model, SqliteDatabase, fn, SQL from playhouse.sqlite_ext import JSONField from eval_protocol.event_bus.sqlite_event_bus_database import ( @@ -67,12 +67,55 @@ def _do_upsert(self, rollout_id: str, data: dict) -> None: else: self._EvaluationRow.create(rollout_id=rollout_id, data=data) - def read_rows(self, rollout_id: Optional[str] = None) -> List[dict]: - if rollout_id is None: - query = self._EvaluationRow.select().dicts() - else: - query = self._EvaluationRow.select().dicts().where(self._EvaluationRow.rollout_id == rollout_id) - results = list(query) + def read_rows( + self, + rollout_id: Optional[str] = None, + invocation_ids: Optional[List[str]] = None, + limit: Optional[int] = None, + ) -> List[dict]: + """ + Read evaluation rows from the database with optional filtering. + + Args: + rollout_id: Filter by a specific rollout_id (exact match) + invocation_ids: Filter by a list of invocation_ids (rows matching any) + limit: Maximum number of rows to return (most recent first) + + Returns: + List of evaluation row data dictionaries + """ + query = self._EvaluationRow.select() + + if rollout_id is not None: + query = query.where(self._EvaluationRow.rollout_id == rollout_id) + + # Apply invocation_ids filter using JSON extraction + # Note: This filters rows where data->'execution_metadata'->>'invocation_id' matches any of the provided IDs + if invocation_ids is not None and len(invocation_ids) > 0: + # Build a condition that matches any of the invocation_ids + # Using SQLite JSON extraction: json_extract(data, '$.execution_metadata.invocation_id') + invocation_conditions = [] + for inv_id in invocation_ids: + invocation_conditions.append( + fn.json_extract(self._EvaluationRow.data, "$.execution_metadata.invocation_id") == inv_id + ) + # Combine with OR + if len(invocation_conditions) == 1: + query = query.where(invocation_conditions[0]) + else: + from functools import reduce + from operator import or_ + + combined_condition = reduce(or_, invocation_conditions) + query = query.where(combined_condition) + + # Order by rowid descending to get most recent rows first + query = query.order_by(SQL("rowid DESC")) + + if limit is not None: + query = query.limit(limit) + + results = list(query.dicts()) return [result["data"] for result in results] def delete_row(self, rollout_id: str) -> int: diff --git a/eval_protocol/utils/logs_server.py b/eval_protocol/utils/logs_server.py index adf44c57..a17b3ed8 100644 --- a/eval_protocol/utils/logs_server.py +++ b/eval_protocol/utils/logs_server.py @@ -45,6 +45,9 @@ def enable_debug_mode(): print("Debug mode enabled for all relevant loggers") +DEFAULT_MAX_LOGS_LIMIT = 1000 # Default limit for initial log load to prevent slowdowns + + class WebSocketManager: """Manages WebSocket connections and broadcasts messages.""" @@ -53,17 +56,42 @@ def __init__(self): self._broadcast_queue: Queue = Queue() self._broadcast_task: Optional[asyncio.Task] = None self._lock = threading.Lock() + # Track which invocation_ids each connection is subscribed to (None = all) + self._connection_filters: Dict[WebSocket, Optional[List[str]]] = {} - async def connect(self, websocket: WebSocket): + async def connect( + self, + websocket: WebSocket, + invocation_ids: Optional[List[str]] = None, + limit: Optional[int] = None, + ): + """ + Connect a WebSocket client and send initial logs. + + Args: + websocket: The WebSocket connection + invocation_ids: Optional list of invocation_ids to filter logs + limit: Maximum number of logs to send initially (defaults to DEFAULT_MAX_LOGS_LIMIT) + """ logger.debug("[WEBSOCKET_CONNECT] New websocket connection attempt") await websocket.accept() with self._lock: self.active_connections.append(websocket) + self._connection_filters[websocket] = invocation_ids connection_count = len(self.active_connections) - logger.info(f"[WEBSOCKET_CONNECT] WebSocket connected. Total connections: {connection_count}") + logger.info( + f"[WEBSOCKET_CONNECT] WebSocket connected. Total connections: {connection_count}, " + f"invocation_ids filter: {invocation_ids}, limit: {limit}" + ) + + # Use provided limit or default + effective_limit = limit if limit is not None else DEFAULT_MAX_LOGS_LIMIT - logger.debug("[WEBSOCKET_CONNECT] Reading logs for initialization") - logs = default_logger.read() + logger.debug( + f"[WEBSOCKET_CONNECT] Reading logs for initialization with " + f"invocation_ids={invocation_ids}, limit={effective_limit}" + ) + logs = default_logger.read(invocation_ids=invocation_ids, limit=effective_limit) logger.debug(f"[WEBSOCKET_CONNECT] Found {len(logs)} logs to send") data = { @@ -82,6 +110,9 @@ def disconnect(self, websocket: WebSocket): logger.debug("[WEBSOCKET_DISCONNECT] Removed websocket from active connections") else: logger.debug("[WEBSOCKET_DISCONNECT] Websocket was not in active connections") + # Clean up connection filter + if websocket in self._connection_filters: + del self._connection_filters[websocket] connection_count = len(self.active_connections) logger.info(f"[WEBSOCKET_DISCONNECT] WebSocket disconnected. Total connections: {connection_count}") @@ -89,9 +120,15 @@ def broadcast_row_upserted(self, row: "EvaluationRow"): """Broadcast a row-upsert event to all connected clients. Safe no-op if server loop is not running or there are no connections. + Messages are only sent to connections whose invocation_id filter matches the row, + or to connections with no filter (subscribed to all). """ rollout_id = row.execution_metadata.rollout_id if row.execution_metadata else "unknown" - logger.debug(f"[WEBSOCKET_BROADCAST] Starting broadcast for rollout_id: {rollout_id}") + row_invocation_id = row.execution_metadata.invocation_id if row.execution_metadata else None + logger.debug( + f"[WEBSOCKET_BROADCAST] Starting broadcast for rollout_id: {rollout_id}, " + f"invocation_id: {row_invocation_id}" + ) with self._lock: active_connections_count = len(self.active_connections) @@ -105,9 +142,9 @@ def broadcast_row_upserted(self, row: "EvaluationRow"): f"[WEBSOCKET_BROADCAST] Successfully serialized message (length: {len(json_message)}) for rollout_id: {rollout_id}" ) - # Queue the message for broadcasting in the main event loop + # Queue the message for broadcasting in the main event loop, along with invocation_id for filtering logger.debug(f"[WEBSOCKET_BROADCAST] Queuing message for broadcast for rollout_id: {rollout_id}") - self._broadcast_queue.put(json_message) + self._broadcast_queue.put((json_message, row_invocation_id)) logger.debug(f"[WEBSOCKET_BROADCAST] Successfully queued message for rollout_id: {rollout_id}") except Exception as e: logger.error( @@ -121,15 +158,25 @@ async def _start_broadcast_loop(self): try: # Wait for a message to be queued logger.debug("[WEBSOCKET_BROADCAST_LOOP] Waiting for message from queue") - message_data = await asyncio.get_event_loop().run_in_executor(None, self._broadcast_queue.get) + queue_item = await asyncio.get_event_loop().run_in_executor(None, self._broadcast_queue.get) + + # Queue item is a tuple of (json_message, row_invocation_id) + if isinstance(queue_item, tuple): + json_message, row_invocation_id = queue_item + else: + # Backward compatibility: if it's just a string, send to all + json_message = str(queue_item) + row_invocation_id = None + logger.debug( - f"[WEBSOCKET_BROADCAST_LOOP] Retrieved message from queue (length: {len(str(message_data))})" + f"[WEBSOCKET_BROADCAST_LOOP] Retrieved message from queue (length: {len(json_message)}), " + f"invocation_id: {row_invocation_id}" ) - # Regular string message for all connections - logger.debug("[WEBSOCKET_BROADCAST_LOOP] Sending message to all connections") - await self._send_text_to_all_connections(str(message_data)) - logger.debug("[WEBSOCKET_BROADCAST_LOOP] Successfully sent message to all connections") + # Send message to connections that match the filter + logger.debug("[WEBSOCKET_BROADCAST_LOOP] Sending message to filtered connections") + await self._send_text_to_filtered_connections(json_message, row_invocation_id) + logger.debug("[WEBSOCKET_BROADCAST_LOOP] Successfully sent message to connections") except Exception as e: logger.error(f"[WEBSOCKET_BROADCAST_LOOP] Error in broadcast loop: {e}") @@ -138,28 +185,54 @@ async def _start_broadcast_loop(self): logger.info("[WEBSOCKET_BROADCAST_LOOP] Broadcast loop cancelled") break - async def _send_text_to_all_connections(self, text: str): + async def _send_text_to_filtered_connections(self, text: str, row_invocation_id: Optional[str] = None): + """ + Send text to connections that match the invocation_id filter. + + Args: + text: The message to send + row_invocation_id: The invocation_id of the row being sent. + Connections with no filter (None) receive all messages. + Connections with a filter only receive messages where row_invocation_id is in their filter. + """ with self._lock: connections = list(self.active_connections) + connection_filters = dict(self._connection_filters) + + # Filter connections based on their subscribed invocation_ids + eligible_connections = [] + for conn in connections: + conn_filter = connection_filters.get(conn) + if conn_filter is None: + # No filter means subscribed to all + eligible_connections.append(conn) + elif row_invocation_id is not None and row_invocation_id in conn_filter: + # Row's invocation_id matches connection's filter + eligible_connections.append(conn) + # else: skip this connection + + logger.debug( + f"[WEBSOCKET_SEND] Attempting to send to {len(eligible_connections)} of {len(connections)} connections " + f"(filtered by invocation_id: {row_invocation_id})" + ) - logger.debug(f"[WEBSOCKET_SEND] Attempting to send to {len(connections)} connections") - - if not connections: - logger.debug("[WEBSOCKET_SEND] No connections available, skipping send") + if not eligible_connections: + logger.debug("[WEBSOCKET_SEND] No eligible connections, skipping send") return tasks = [] - failed_connections = [] + task_connections = [] # Track which connection each task corresponds to - for i, connection in enumerate(connections): + for i, connection in enumerate(eligible_connections): try: logger.debug(f"[WEBSOCKET_SEND] Preparing to send to connection {i}") tasks.append(connection.send_text(text)) + task_connections.append(connection) except Exception as e: logger.error(f"[WEBSOCKET_SEND] Failed to prepare send to WebSocket {i}: {e}") - failed_connections.append(connection) # Execute all sends in parallel + failed_connections = [] if tasks: logger.debug(f"[WEBSOCKET_SEND] Executing {len(tasks)} parallel sends") results = await asyncio.gather(*tasks, return_exceptions=True) @@ -169,7 +242,7 @@ async def _send_text_to_all_connections(self, text: str): for i, result in enumerate(results): if isinstance(result, Exception): logger.error(f"[WEBSOCKET_SEND] Failed to send text to WebSocket {i}: {result}") - failed_connections.append(connections[i]) + failed_connections.append(task_connections[i]) else: logger.debug(f"[WEBSOCKET_SEND] Successfully sent to connection {i}") @@ -180,6 +253,8 @@ async def _send_text_to_all_connections(self, text: str): for connection in failed_connections: try: self.active_connections.remove(connection) + if connection in self._connection_filters: + del self._connection_filters[connection] except ValueError: pass @@ -393,7 +468,27 @@ def _setup_websocket_routes(self): @self.app.websocket("/ws") async def websocket_endpoint(websocket: WebSocket): - await self.websocket_manager.connect(websocket) + # Parse query parameters from WebSocket connection URL + # invocation_ids: comma-separated list of invocation IDs to filter + # limit: maximum number of initial logs to load + query_params = websocket.query_params + invocation_ids_param = query_params.get("invocation_ids") + limit_param = query_params.get("limit") + + invocation_ids: Optional[List[str]] = None + if invocation_ids_param: + invocation_ids = [id.strip() for id in invocation_ids_param.split(",") if id.strip()] + logger.info(f"[WEBSOCKET] Client filtering by invocation_ids: {invocation_ids}") + + limit: Optional[int] = None + if limit_param: + try: + limit = int(limit_param) + logger.info(f"[WEBSOCKET] Client requested limit: {limit}") + except ValueError: + logger.warning(f"[WEBSOCKET] Invalid limit parameter: {limit_param}") + + await self.websocket_manager.connect(websocket, invocation_ids=invocation_ids, limit=limit) try: while True: # Keep connection alive (for evaluation row updates) diff --git a/tests/dataset_logger/test_sqlite_dataset_logger_adapter.py b/tests/dataset_logger/test_sqlite_dataset_logger_adapter.py index 7689d085..78af8dc0 100644 --- a/tests/dataset_logger/test_sqlite_dataset_logger_adapter.py +++ b/tests/dataset_logger/test_sqlite_dataset_logger_adapter.py @@ -72,9 +72,83 @@ def test_create_multiple_logs_and_read_all(): # Verify we got all 3 rows back assert len(saved_rows) == 3 - # Verify each row matches the original + # Build a map of saved rows by row_id for order-independent comparison + # (read() now returns rows in descending order by insertion time) + saved_by_row_id = {r.input_metadata.row_id: r for r in saved_rows} + + # Verify each row matches the original (order-independent) for i, original_row in enumerate(rows): - saved_row = saved_rows[i] + row_id = f"row_{i}" + saved_row = saved_by_row_id[row_id] assert original_row.messages == saved_row.messages assert original_row.input_metadata == saved_row.input_metadata - assert original_row.input_metadata.row_id == f"row_{i}" + assert saved_row.input_metadata.row_id == row_id + + +def test_read_with_invocation_ids_filter(): + """Test filtering rows by invocation_ids.""" + from eval_protocol.models import ExecutionMetadata + + db_path = get_db_path("test_read_with_invocation_ids_filter") + # delete the db file if it exists + if os.path.exists(db_path): + os.remove(db_path) + store = SqliteEvaluationRowStore(db_path=db_path) + logger = SqliteDatasetLoggerAdapter(store=store) + + # Create rows with different invocation_ids + inv_ids = ["inv-alpha", "inv-beta", "inv-gamma"] + for i, inv_id in enumerate(inv_ids): + messages = [Message(role="user", content=f"Hello {inv_id}")] + input_metadata = InputMetadata(row_id=f"row_{i}") + execution_metadata = ExecutionMetadata(invocation_id=inv_id) + row = EvaluationRow( + input_metadata=input_metadata, + messages=messages, + execution_metadata=execution_metadata, + ) + logger.log(row) + + # Test 1: Read all (no filter) + all_rows = logger.read() + assert len(all_rows) == 3 + + # Test 2: Filter by single invocation_id + filtered_rows = logger.read(invocation_ids=["inv-alpha"]) + assert len(filtered_rows) == 1 + assert filtered_rows[0].execution_metadata.invocation_id == "inv-alpha" + + # Test 3: Filter by multiple invocation_ids + filtered_rows = logger.read(invocation_ids=["inv-alpha", "inv-gamma"]) + assert len(filtered_rows) == 2 + inv_ids_found = {r.execution_metadata.invocation_id for r in filtered_rows} + assert inv_ids_found == {"inv-alpha", "inv-gamma"} + + # Test 4: Filter by non-existent invocation_id + filtered_rows = logger.read(invocation_ids=["inv-nonexistent"]) + assert len(filtered_rows) == 0 + + +def test_read_with_limit(): + """Test limiting the number of rows returned.""" + db_path = get_db_path("test_read_with_limit") + # delete the db file if it exists + if os.path.exists(db_path): + os.remove(db_path) + store = SqliteEvaluationRowStore(db_path=db_path) + logger = SqliteDatasetLoggerAdapter(store=store) + + # Create 10 rows + for i in range(10): + messages = [Message(role="user", content=f"Hello {i}")] + input_metadata = InputMetadata(row_id=f"row_{i}") + row = EvaluationRow(input_metadata=input_metadata, messages=messages) + logger.log(row) + + # Test with limit + limited_rows = logger.read(limit=3) + assert len(limited_rows) == 3 + + # Verify we got the most recent rows (inserted last, returned first) + row_ids = [r.input_metadata.row_id for r in limited_rows] + assert row_ids == ["row_9", "row_8", "row_7"] # Most recent first diff --git a/tests/test_logs_server.py b/tests/test_logs_server.py index f17b3cf9..e8da6ca4 100644 --- a/tests/test_logs_server.py +++ b/tests/test_logs_server.py @@ -90,8 +90,11 @@ def test_broadcast_row_upserted(self): # Test that message is queued assert not manager._broadcast_queue.empty() - queued_message = manager._broadcast_queue.get_nowait() - data = json.loads(queued_message) + queued_item = manager._broadcast_queue.get_nowait() + # Queue item is now a tuple of (json_message, row_invocation_id) + assert isinstance(queued_item, tuple) + json_message, row_invocation_id = queued_item + data = json.loads(json_message) assert data["type"] == "log" assert "row" in data assert data["row"]["messages"][0]["content"] == "test" @@ -113,21 +116,22 @@ async def test_broadcast_loop(self): assert manager._broadcast_task is None @pytest.mark.asyncio - async def test_send_text_to_all_connections(self): - """Test sending text to all connections.""" + async def test_send_text_to_filtered_connections(self): + """Test sending text to filtered connections.""" manager = WebSocketManager() mock_websocket1 = AsyncMock() mock_websocket2 = AsyncMock() # Mock default_logger.read() to return empty logs + # Both connections have no filter (None), so they receive all messages with patch.object(default_logger, "read", return_value=[]): await manager.connect(mock_websocket1) await manager.connect(mock_websocket2) test_message = "test message" - await manager._send_text_to_all_connections(test_message) + await manager._send_text_to_filtered_connections(test_message) - # Check that the test message was sent to both websockets + # Check that the test message was sent to both websockets (no filter = receives all) mock_websocket1.send_text.assert_any_call(test_message) mock_websocket2.send_text.assert_any_call(test_message) @@ -151,7 +155,7 @@ async def failing_send_text(text): mock_websocket2.send_text = failing_send_text test_message = "test message" - await manager._send_text_to_all_connections(test_message) + await manager._send_text_to_filtered_connections(test_message) # First websocket should receive the message mock_websocket1.send_text.assert_any_call(test_message) @@ -159,6 +163,52 @@ async def failing_send_text(text): assert len(manager.active_connections) == 1 assert mock_websocket1 in manager.active_connections + @pytest.mark.asyncio + async def test_connect_with_invocation_ids_filter(self): + """Test connecting with invocation_ids filter.""" + manager = WebSocketManager() + mock_websocket = AsyncMock() + + # Mock default_logger.read() to verify filter is passed + with patch.object(default_logger, "read", return_value=[]) as mock_read: + await manager.connect(mock_websocket, invocation_ids=["inv-123", "inv-456"]) + + # Verify that read was called with the invocation_ids filter + mock_read.assert_called_once_with(invocation_ids=["inv-123", "inv-456"], limit=1000) + + # Verify that the connection has the filter stored + assert manager._connection_filters[mock_websocket] == ["inv-123", "inv-456"] + + @pytest.mark.asyncio + async def test_send_text_to_filtered_connections_respects_filter(self): + """Test that messages are only sent to connections matching the filter.""" + manager = WebSocketManager() + mock_websocket_all = AsyncMock() # No filter - receives all + mock_websocket_inv1 = AsyncMock() # Filter for inv-123 + mock_websocket_inv2 = AsyncMock() # Filter for inv-456 + + # Connect with different filters + with patch.object(default_logger, "read", return_value=[]): + await manager.connect(mock_websocket_all) # No filter + await manager.connect(mock_websocket_inv1, invocation_ids=["inv-123"]) + await manager.connect(mock_websocket_inv2, invocation_ids=["inv-456"]) + + # Reset mocks to clear the initial send_text calls from connect + mock_websocket_all.reset_mock() + mock_websocket_inv1.reset_mock() + mock_websocket_inv2.reset_mock() + + # Send a message for inv-123 + test_message = '{"type": "log", "row": {}}' + await manager._send_text_to_filtered_connections(test_message, row_invocation_id="inv-123") + + # mock_websocket_all should receive (no filter) + mock_websocket_all.send_text.assert_called_once_with(test_message) + # mock_websocket_inv1 should receive (filter matches) + mock_websocket_inv1.send_text.assert_called_once_with(test_message) + # mock_websocket_inv2 should NOT receive (filter doesn't match) + mock_websocket_inv2.send_text.assert_not_called() + class TestEvaluationWatcher: """Test EvaluationWatcher class.""" diff --git a/vite-app/src/App.tsx b/vite-app/src/App.tsx index fe55b4d3..629b2148 100644 --- a/vite-app/src/App.tsx +++ b/vite-app/src/App.tsx @@ -8,7 +8,11 @@ import { EvaluationRowSchema, type EvaluationRow } from "./types/eval-protocol"; import { WebSocketServerMessageSchema } from "./types/websocket"; import { GlobalState } from "./GlobalState"; import logoLight from "./assets/logo-light.png"; -import { getWebSocketUrl, discoverServerConfig } from "./config"; +import { + getWebSocketUrl, + discoverServerConfig, + extractInvocationIdsFromUrl, +} from "./config"; export const state = new GlobalState(); @@ -30,7 +34,19 @@ const App = observer(() => { return; // Already connected or connecting. This will happen in React strict mode. } - const ws = new WebSocket(getWebSocketUrl()); + // Extract invocation_ids from URL filter for server-side filtering + const invocationIds = extractInvocationIdsFromUrl(); + const wsUrl = getWebSocketUrl( + invocationIds.length > 0 ? invocationIds : undefined + ); + console.log( + "Connecting to WebSocket:", + wsUrl, + "with invocation_ids:", + invocationIds + ); + + const ws = new WebSocket(wsUrl); wsRef.current = ws; ws.onopen = () => { diff --git a/vite-app/src/config.ts b/vite-app/src/config.ts index 09a28c32..23a18bdc 100644 --- a/vite-app/src/config.ts +++ b/vite-app/src/config.ts @@ -14,10 +14,60 @@ export const config = { }, }; -// Helper function to build WebSocket URL -export const getWebSocketUrl = (): string => { +/** + * Extract invocation_ids from URL filterConfig parameter. + * Looks for filters on $.execution_metadata.invocation_id field. + */ +export const extractInvocationIdsFromUrl = (): string[] => { + try { + const urlParams = new URLSearchParams(window.location.search); + const filterConfigStr = urlParams.get('filterConfig'); + if (!filterConfigStr) { + return []; + } + + // Parse the filter config JSON + const filterConfig = JSON.parse(filterConfigStr); + const invocationIds: string[] = []; + + // filterConfig is an array of FilterGroups + if (Array.isArray(filterConfig)) { + for (const group of filterConfig) { + if (group.filters && Array.isArray(group.filters)) { + for (const filter of group.filters) { + // Check if this filter is on invocation_id field + if ( + filter.field === '$.execution_metadata.invocation_id' && + (filter.operator === '==' || filter.operator === 'equals') && + filter.value + ) { + invocationIds.push(filter.value); + } + } + } + } + } + + return invocationIds; + } catch (error) { + console.warn('Failed to extract invocation_ids from URL:', error); + return []; + } +}; + +// Helper function to build WebSocket URL with optional invocation_ids filter +export const getWebSocketUrl = (invocationIds?: string[]): string => { const { protocol, host, port } = config.websocket; - return `${protocol}://${host}:${port}/ws`; + const baseUrl = `${protocol}://${host}:${port}/ws`; + + // If invocation_ids provided, add as query param + if (invocationIds && invocationIds.length > 0) { + const params = new URLSearchParams(); + params.set('invocation_ids', invocationIds.join(',')); + return `${baseUrl}?${params.toString()}`; + } + + return baseUrl; }; // Helper function to build API URL