diff --git a/packages/pynumaflow/pynumaflow/mapper/async_server.py b/packages/pynumaflow/pynumaflow/mapper/async_server.py index 3ef1ed46..34c9b92a 100644 --- a/packages/pynumaflow/pynumaflow/mapper/async_server.py +++ b/packages/pynumaflow/pynumaflow/mapper/async_server.py @@ -60,6 +60,7 @@ def __init__( max_message_size=MAX_MESSAGE_SIZE, max_threads=NUM_THREADS_DEFAULT, server_info_file=MAP_SERVER_INFO_FILE_PATH, + shutdown_callback=None, ): """ Create a new grpc Asynchronous Map Server instance. @@ -77,6 +78,7 @@ def __init__( self.max_threads = min(max_threads, MAX_NUM_THREADS) self.max_message_size = max_message_size self.server_info_file = server_info_file + self.shutdwon_callback = shutdown_callback self.mapper_instance = mapper_instance @@ -92,7 +94,7 @@ def start(self) -> None: Starter function for the Async server class, need a separate caller so that all the async coroutines can be started from a single context """ - aiorun.run(self.aexec(), use_uvloop=True) + aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdwon_callback) async def aexec(self) -> None: """ diff --git a/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py b/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py index d718a6a5..90fdd3c7 100644 --- a/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py +++ b/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py @@ -37,6 +37,7 @@ def __init__( max_message_size=MAX_MESSAGE_SIZE, max_threads=NUM_THREADS_DEFAULT, server_info_file=MAP_SERVER_INFO_FILE_PATH, + shutdown_callback=None, ): """ Create a new grpc Async Map Stream Server instance. @@ -98,6 +99,7 @@ async def map_stream_handler(_: list[str], datum: Datum) -> AsyncIterable[Messag self.max_threads = min(max_threads, MAX_NUM_THREADS) self.max_message_size = max_message_size self.server_info_file = server_info_file + self.shutdwon_callback = shutdown_callback self._server_options = [ ("grpc.max_send_message_length", self.max_message_size), @@ -111,7 +113,7 @@ def start(self): Starter function for the Async Map Stream server, we need a separate caller to the aexec so that all the async coroutines can be started from a single context """ - aiorun.run(self.aexec(), use_uvloop=True) + aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdwon_callback) async def aexec(self): """ diff --git a/packages/pynumaflow/pynumaflow/reducer/async_server.py b/packages/pynumaflow/pynumaflow/reducer/async_server.py index 4103fe98..8f6d06e7 100644 --- a/packages/pynumaflow/pynumaflow/reducer/async_server.py +++ b/packages/pynumaflow/pynumaflow/reducer/async_server.py @@ -124,6 +124,7 @@ def __init__( max_message_size=MAX_MESSAGE_SIZE, max_threads=NUM_THREADS_DEFAULT, server_info_file=REDUCE_SERVER_INFO_FILE_PATH, + shutdown_callback=None, ): init_kwargs = init_kwargs or {} self.reducer_handler = get_handler(reducer_instance, init_args, init_kwargs) @@ -131,6 +132,7 @@ def __init__( self.max_message_size = max_message_size self.max_threads = min(max_threads, MAX_NUM_THREADS) self.server_info_file = server_info_file + self.shutdwon_callback = shutdown_callback self._server_options = [ ("grpc.max_send_message_length", self.max_message_size), @@ -147,7 +149,7 @@ def start(self): _LOGGER.info( "Starting Async Reduce Server", ) - aiorun.run(self.aexec(), use_uvloop=True) + aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdwon_callback) async def aexec(self): """ diff --git a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py index f974c1a0..771b5f94 100644 --- a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py +++ b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py @@ -138,6 +138,7 @@ def __init__( max_message_size=MAX_MESSAGE_SIZE, max_threads=NUM_THREADS_DEFAULT, server_info_file=REDUCE_STREAM_SERVER_INFO_FILE_PATH, + shutdown_callback=None, ): init_kwargs = init_kwargs or {} self.reduce_stream_handler = get_handler(reduce_stream_instance, init_args, init_kwargs) @@ -145,6 +146,7 @@ def __init__( self.max_message_size = max_message_size self.max_threads = min(max_threads, MAX_NUM_THREADS) self.server_info_file = server_info_file + self.shutdwon_callback = shutdown_callback self._server_options = [ ("grpc.max_send_message_length", self.max_message_size), @@ -161,7 +163,7 @@ def start(self): _LOGGER.info( "Starting Async Reduce Stream Server", ) - aiorun.run(self.aexec(), use_uvloop=True) + aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdwon_callback) async def aexec(self): """ diff --git a/packages/pynumaflow/pynumaflow/sinker/async_server.py b/packages/pynumaflow/pynumaflow/sinker/async_server.py index 40020ced..3628f01c 100644 --- a/packages/pynumaflow/pynumaflow/sinker/async_server.py +++ b/packages/pynumaflow/pynumaflow/sinker/async_server.py @@ -88,6 +88,7 @@ def __init__( max_message_size=MAX_MESSAGE_SIZE, max_threads=NUM_THREADS_DEFAULT, server_info_file=SINK_SERVER_INFO_FILE_PATH, + shutdown_callback=None, ): # If the container type is fallback sink, then use the fallback sink address and path. if os.getenv(ENV_UD_CONTAINER_TYPE, "") == UD_CONTAINER_FALLBACK_SINK: @@ -103,6 +104,7 @@ def __init__( self.max_threads = min(max_threads, MAX_NUM_THREADS) self.max_message_size = max_message_size self.server_info_file = server_info_file + self.shutdwon_callback = shutdown_callback self.sinker_instance = sinker_instance @@ -118,7 +120,7 @@ def start(self): Starter function for the Async server class, need a separate caller so that all the async coroutines can be started from a single context """ - aiorun.run(self.aexec(), use_uvloop=True) + aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdwon_callback) async def aexec(self): """ diff --git a/packages/pynumaflow/pynumaflow/sourcer/async_server.py b/packages/pynumaflow/pynumaflow/sourcer/async_server.py index 7e312213..3da75cf8 100644 --- a/packages/pynumaflow/pynumaflow/sourcer/async_server.py +++ b/packages/pynumaflow/pynumaflow/sourcer/async_server.py @@ -29,6 +29,7 @@ def __init__( max_message_size=MAX_MESSAGE_SIZE, max_threads=NUM_THREADS_DEFAULT, server_info_file=SOURCE_SERVER_INFO_FILE_PATH, + shutdown_callback=None, ): """ Create a new grpc Async Source Server instance. @@ -138,6 +139,7 @@ async def partitions_handler(self) -> PartitionsResponse: self.max_threads = min(max_threads, MAX_NUM_THREADS) self.max_message_size = max_message_size self.server_info_file = server_info_file + self.shutdown_callback = shutdown_callback self.sourcer_instance = sourcer_instance @@ -153,7 +155,7 @@ def start(self): Starter function for the Async server class, need a separate caller so that all the async coroutines can be started from a single context """ - aiorun.run(self.aexec(), use_uvloop=True) + aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdwon_callback) async def aexec(self): """