From 3e287ff1bb91db0a82acb4ec1d27685214dc934b Mon Sep 17 00:00:00 2001 From: leavesster <11785335+leavesster@users.noreply.github.com> Date: Sat, 31 Jan 2026 14:14:53 +0800 Subject: [PATCH] fix: replace busy-wait with threading.Event in notify_block_ready The previous implementation used a while True loop that would spin the CPU waiting for a response. This is replaced with threading.Event.wait() which properly blocks without consuming CPU resources. Changes: - Add threading.Event to wait for message response - Add optional timeout parameter to prevent indefinite blocking - Raise TimeoutError with descriptive message on timeout - Properly clean up subscription on timeout - Add comprehensive unit tests for the new implementation --- oocana/oocana/mainframe.py | 30 +++++- oocana/tests/test_mainframe_notify.py | 140 ++++++++++++++++++++++++++ 2 files changed, 165 insertions(+), 5 deletions(-) create mode 100644 oocana/tests/test_mainframe_notify.py diff --git a/oocana/oocana/mainframe.py b/oocana/oocana/mainframe.py index 3f7083e6..6c2d4d33 100644 --- a/oocana/oocana/mainframe.py +++ b/oocana/oocana/mainframe.py @@ -4,6 +4,7 @@ import operator from urllib.parse import urlparse import uuid +import threading from .data import BlockDict, JobDict, dumps, EXECUTOR_NAME import logging from typing import Optional, Callable, Any @@ -104,15 +105,30 @@ def notify_executor_ready(self, session_id: str, package: str | None, identifier "debug_port": debug_port, }), qos=1) - def notify_block_ready(self, session_id: str, job_id: str) -> dict: + def notify_block_ready(self, session_id: str, job_id: str, timeout: Optional[float] = None) -> dict: + """ + Notify that a block is ready and wait for input message. + Args: + session_id: The session ID + job_id: The job ID + timeout: Optional timeout in seconds. If None, wait indefinitely. + + Returns: + The input message payload as a dict + + Raises: + TimeoutError: If timeout is specified and no message is received within the timeout + """ topic = f"inputs/{session_id}/{job_id}" replay = None + event = threading.Event() def on_message_once(_client, _userdata, message): nonlocal replay self.client.unsubscribe(topic) replay = loads(message.payload) + event.set() self.client.subscribe(topic, qos=1) self.client.message_callback_add(topic, on_message_once) @@ -123,10 +139,14 @@ def on_message_once(_client, _userdata, message): "job_id": job_id, }), qos=1) - while True: - if replay is not None: - self._logger.info("notify ready success in {} {}".format(session_id, job_id)) - return replay + if event.wait(timeout=timeout): + self._logger.info("notify ready success in {} {}".format(session_id, job_id)) + return replay # type: ignore + else: + # Timeout occurred, clean up subscription + self.client.unsubscribe(topic) + self.client.message_callback_remove(topic) + raise TimeoutError(f"Timeout waiting for block ready response in session {session_id}, job {job_id}") def add_request_response_callback(self, session_id: str, request_id: str, callback: Callable[[Any], Any]): """Add a callback to be called when an error occurs while running a block.""" diff --git a/oocana/tests/test_mainframe_notify.py b/oocana/tests/test_mainframe_notify.py new file mode 100644 index 00000000..aaa38d5a --- /dev/null +++ b/oocana/tests/test_mainframe_notify.py @@ -0,0 +1,140 @@ +import unittest +from unittest.mock import MagicMock, patch, PropertyMock +import threading +import time +from oocana import Mainframe + + +class MockMessage: + """Mock MQTT message for testing.""" + def __init__(self, payload: bytes): + self.payload = payload + + +class TestNotifyBlockReady(unittest.TestCase): + """Test cases for Mainframe.notify_block_ready method.""" + + def setUp(self): + # Patch the mqtt client to avoid real network connections + self.mock_client_patcher = patch('paho.mqtt.client.Client') + self.mock_client_class = self.mock_client_patcher.start() + self.mock_client = MagicMock() + self.mock_client_class.return_value = self.mock_client + self.mock_client.is_connected.return_value = True + + self.mainframe = Mainframe('mqtt://localhost:1883') + self.mainframe.client = self.mock_client + + def tearDown(self): + self.mock_client_patcher.stop() + + def test_notify_block_ready_receives_response(self): + """Test that notify_block_ready correctly waits for and returns response.""" + session_id = 'test-session' + job_id = 'test-job' + expected_payload = {'inputs': {'key': 'value'}} + + # Simulate message callback being triggered + def trigger_callback(*args, **kwargs): + # Get the callback that was registered + callback = self.mock_client.message_callback_add.call_args[0][1] + # Create a mock message + import simplejson + mock_message = MockMessage(simplejson.dumps(expected_payload).encode()) + # Trigger the callback in a separate thread + callback(None, None, mock_message) + + # Make subscribe trigger the callback after a short delay + def delayed_trigger(*args, **kwargs): + timer = threading.Timer(0.1, trigger_callback) + timer.start() + + self.mock_client.publish.side_effect = delayed_trigger + + result = self.mainframe.notify_block_ready(session_id, job_id, timeout=5.0) + + self.assertEqual(result, expected_payload) + self.mock_client.subscribe.assert_called_once() + self.mock_client.publish.assert_called_once() + + def test_notify_block_ready_timeout(self): + """Test that notify_block_ready raises TimeoutError on timeout.""" + session_id = 'test-session' + job_id = 'test-job' + + # Don't trigger any callback, let it timeout + with self.assertRaises(TimeoutError) as context: + self.mainframe.notify_block_ready(session_id, job_id, timeout=0.1) + + self.assertIn(session_id, str(context.exception)) + self.assertIn(job_id, str(context.exception)) + # Verify cleanup was called + self.mock_client.unsubscribe.assert_called() + self.mock_client.message_callback_remove.assert_called() + + def test_notify_block_ready_unsubscribes_on_success(self): + """Test that the topic is unsubscribed after successful message receipt.""" + session_id = 'test-session' + job_id = 'test-job' + expected_topic = f"inputs/{session_id}/{job_id}" + + def trigger_callback(*args, **kwargs): + callback = self.mock_client.message_callback_add.call_args[0][1] + import simplejson + mock_message = MockMessage(simplejson.dumps({}).encode()) + callback(None, None, mock_message) + + self.mock_client.publish.side_effect = trigger_callback + + self.mainframe.notify_block_ready(session_id, job_id, timeout=5.0) + + # Verify unsubscribe was called with the correct topic + self.mock_client.unsubscribe.assert_called_with(expected_topic) + + def test_notify_block_ready_publishes_correct_message(self): + """Test that the correct BlockReady message is published.""" + session_id = 'test-session' + job_id = 'test-job' + + def trigger_callback(*args, **kwargs): + callback = self.mock_client.message_callback_add.call_args[0][1] + import simplejson + mock_message = MockMessage(simplejson.dumps({}).encode()) + callback(None, None, mock_message) + + self.mock_client.publish.side_effect = trigger_callback + + self.mainframe.notify_block_ready(session_id, job_id, timeout=5.0) + + # Check that publish was called with correct topic and payload + publish_call = self.mock_client.publish.call_args + self.assertEqual(publish_call[0][0], f"session/{session_id}") + + import simplejson + payload = simplejson.loads(publish_call[0][1]) + self.assertEqual(payload['type'], 'BlockReady') + self.assertEqual(payload['session_id'], session_id) + self.assertEqual(payload['job_id'], job_id) + + def test_notify_block_ready_no_cpu_spin(self): + """Test that notify_block_ready does not spin CPU while waiting.""" + session_id = 'test-session' + job_id = 'test-job' + + start_time = time.time() + + # Use a short timeout to verify it actually waits + try: + self.mainframe.notify_block_ready(session_id, job_id, timeout=0.2) + except TimeoutError: + pass + + elapsed_time = time.time() - start_time + + # If it was busy-waiting, it would return almost immediately + # With Event.wait(), it should wait close to the timeout + self.assertGreaterEqual(elapsed_time, 0.15) + + +if __name__ == '__main__': + unittest.main()