From e49d3227f0ada17af8857cae30ab56cd8526ba58 Mon Sep 17 00:00:00 2001 From: srao12 Date: Mon, 2 Feb 2026 12:35:57 -0800 Subject: [PATCH] feat: add metadata support for source transformer Signed-off-by: srao12 --- .../pynumaflow/sourcetransformer/__init__.py | 3 + .../pynumaflow/sourcetransformer/_dtypes.py | 45 +++++- .../servicer/_async_servicer.py | 7 + .../sourcetransformer/servicer/_servicer.py | 7 + .../tests/sourcetransform/test_async.py | 117 ++++++++++++++ .../tests/sourcetransform/test_messages.py | 68 ++++++++- .../tests/sourcetransform/test_sync_server.py | 79 +++++++++- .../pynumaflow/tests/sourcetransform/utils.py | 143 ++++++++++++++---- 8 files changed, 434 insertions(+), 35 deletions(-) diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/__init__.py b/packages/pynumaflow/pynumaflow/sourcetransformer/__init__.py index 8eee3786..f029c6ef 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/__init__.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/__init__.py @@ -8,6 +8,7 @@ from pynumaflow.sourcetransformer.multiproc_server import SourceTransformMultiProcServer from pynumaflow.sourcetransformer.server import SourceTransformServer from pynumaflow.sourcetransformer.async_server import SourceTransformAsyncServer +from pynumaflow._metadata import UserMetadata, SystemMetadata __all__ = [ "Message", @@ -18,4 +19,6 @@ "SourceTransformer", "SourceTransformMultiProcServer", "SourceTransformAsyncServer", + "UserMetadata", + "SystemMetadata", ] diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/_dtypes.py b/packages/pynumaflow/pynumaflow/sourcetransformer/_dtypes.py index 28591000..c7f74446 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/_dtypes.py @@ -7,6 +7,7 @@ from warnings import warn from pynumaflow._constants import DROP +from pynumaflow._metadata import UserMetadata, SystemMetadata M = TypeVar("M", bound="Message") Ms = TypeVar("Ms", bound="Messages") @@ -22,17 +23,24 @@ class Message: event_time: event time of the message, usually extracted from the payload. keys: []string keys for vertex (optional) tags: []string tags for conditional forwarding (optional) + user_metadata: metadata for the message (optional) """ - __slots__ = ("_value", "_keys", "_tags", "_event_time") + __slots__ = ("_value", "_keys", "_tags", "_event_time", "_user_metadata") _keys: list[str] _tags: list[str] _value: bytes _event_time: datetime + _user_metadata: UserMetadata def __init__( - self, value: bytes, event_time: datetime, keys: list[str] = None, tags: list[str] = None + self, + value: bytes, + event_time: datetime, + keys: list[str] = None, + tags: list[str] = None, + user_metadata: Optional[UserMetadata] = None, ): """ Creates a Message object to send value to a vertex. @@ -43,6 +51,7 @@ def __init__( # There is no year 0, so setting following as default event time. self._event_time = event_time or datetime(1, 1, 1, 0, 0) self._value = value or b"" + self._user_metadata = user_metadata or UserMetadata() @classmethod def to_drop(cls: type[M], event_time: datetime) -> M: @@ -64,6 +73,10 @@ def value(self) -> bytes: def tags(self) -> list[str]: return self._tags + @property + def user_metadata(self) -> UserMetadata: + return self._user_metadata + class Messages(Sequence[M]): """ @@ -119,6 +132,8 @@ class Datum: event_time: the event time of the event. watermark: the watermark of the event. headers: the headers of the event. + user_metadata: the user metadata of the event. + system_metadata: the system metadata of the event. Example: ```py @@ -135,13 +150,23 @@ class Datum: ``` """ - __slots__ = ("_keys", "_value", "_event_time", "_watermark", "_headers") + __slots__ = ( + "_keys", + "_value", + "_event_time", + "_watermark", + "_headers", + "_user_metadata", + "_system_metadata", + ) _keys: list[str] _value: bytes _event_time: datetime _watermark: datetime _headers: dict[str, str] + _user_metadata: UserMetadata + _system_metadata: SystemMetadata def __init__( self, @@ -150,6 +175,8 @@ def __init__( event_time: datetime, watermark: datetime, headers: Optional[dict[str, str]] = None, + user_metadata: Optional[UserMetadata] = None, + system_metadata: Optional[SystemMetadata] = None, ): self._keys = keys or list() self._value = value or b"" @@ -160,6 +187,8 @@ def __init__( raise TypeError(f"Wrong data type: {type(watermark)} for Datum.watermark") self._watermark = watermark self._headers = headers or {} + self._user_metadata = user_metadata or UserMetadata() + self._system_metadata = system_metadata or SystemMetadata() @property def keys(self) -> list[str]: @@ -186,6 +215,16 @@ def headers(self) -> dict[str, str]: """Returns the headers of the event.""" return self._headers.copy() + @property + def user_metadata(self) -> UserMetadata: + """Returns the user metadata of the event.""" + return self._user_metadata + + @property + def system_metadata(self) -> SystemMetadata: + """Returns the system metadata of the event.""" + return self._system_metadata + class SourceTransformer(metaclass=ABCMeta): """ diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py index b2e70799..7f0047c5 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py @@ -5,6 +5,7 @@ from google.protobuf import timestamp_pb2 as _timestamp_pb2 from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING +from pynumaflow._metadata import _user_and_system_metadata_from_proto from pynumaflow.proto.sourcetransformer import transform_pb2, transform_pb2_grpc from pynumaflow.shared.asynciter import NonBlockingIterator from pynumaflow.shared.server import handle_async_error @@ -105,12 +106,17 @@ async def _invoke_transform( Invokes the user defined function. """ try: + user_metadata, system_metadata = _user_and_system_metadata_from_proto( + request.request.metadata + ) datum = Datum( keys=list(request.request.keys), value=request.request.value, event_time=request.request.event_time.ToDatetime(), watermark=request.request.watermark.ToDatetime(), headers=dict(request.request.headers), + user_metadata=user_metadata, + system_metadata=system_metadata, ) msgs = await self.__transform_handler(list(request.request.keys), datum) results = [] @@ -123,6 +129,7 @@ async def _invoke_transform( value=msg.value, tags=msg.tags, event_time=event_time_timestamp, + metadata=msg.user_metadata._to_proto(), ) ) await result_queue.put( diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py index 4aea5196..e3c821f7 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py @@ -9,6 +9,7 @@ from pynumaflow.shared.synciter import SyncIterator from pynumaflow.sourcetransformer import Datum from pynumaflow.sourcetransformer._dtypes import SourceTransformCallable +from pynumaflow._metadata import _user_and_system_metadata_from_proto from pynumaflow.proto.sourcetransformer import transform_pb2 from pynumaflow.proto.sourcetransformer import transform_pb2_grpc from pynumaflow.types import NumaflowServicerContext @@ -119,12 +120,17 @@ def _invoke_transformer( self, context, request: transform_pb2.SourceTransformRequest, result_queue: SyncIterator ): try: + user_metadata, system_metadata = _user_and_system_metadata_from_proto( + request.request.metadata + ) d = Datum( keys=list(request.request.keys), value=request.request.value, event_time=request.request.event_time.ToDatetime(), watermark=request.request.watermark.ToDatetime(), headers=dict(request.request.headers), + user_metadata=user_metadata, + system_metadata=system_metadata, ) responses = self.__transform_handler(list(request.request.keys), d) @@ -138,6 +144,7 @@ def _invoke_transformer( value=resp.value, tags=resp.tags, event_time=event_time_timestamp, + metadata=resp.user_metadata._to_proto(), ) ) result_queue.put( diff --git a/packages/pynumaflow/tests/sourcetransform/test_async.py b/packages/pynumaflow/tests/sourcetransform/test_async.py index 05f7f29d..bcdab288 100644 --- a/packages/pynumaflow/tests/sourcetransform/test_async.py +++ b/packages/pynumaflow/tests/sourcetransform/test_async.py @@ -11,6 +11,7 @@ from pynumaflow import setup_logging from pynumaflow._constants import MAX_MESSAGE_SIZE +from pynumaflow.proto.common import metadata_pb2 from pynumaflow.proto.sourcetransformer import transform_pb2_grpc from pynumaflow.sourcetransformer import Datum, Messages, Message, SourceTransformer from pynumaflow.sourcetransformer.async_server import SourceTransformAsyncServer @@ -267,6 +268,122 @@ def test_max_threads(self): self.assertEqual(server.max_threads, 4) +class MetadataAsyncSourceTransformer(SourceTransformer): + """Source transformer that validates and passes through metadata.""" + + async def handler(self, keys: list[str], datum: Datum) -> Messages: + # Validate system metadata + if datum.system_metadata.value("numaflow_version_info", "version") != b"1.0.0": + raise ValueError("System metadata version mismatch") + + val = datum.value + msg = "payload:{} event_time:{} ".format( + val.decode("utf-8"), + datum.event_time, + ) + val = bytes(msg, encoding="utf-8") + messages = Messages() + # Pass user metadata to the output message + messages.append( + Message(val, mock_new_event_time(), keys=keys, user_metadata=datum.user_metadata) + ) + return messages + + +_metadata_s: Server = None +_metadata_channel = grpc.insecure_channel("unix:///tmp/async_st_metadata.sock") +_metadata_loop = None + + +def metadata_startup_callable(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + +def new_metadata_async_st(): + handle = MetadataAsyncSourceTransformer() + server = SourceTransformAsyncServer(source_transform_instance=handle) + return server.servicer + + +async def start_metadata_server(udfs): + _server_options = [ + ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), + ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), + ] + server = grpc.aio.server(options=_server_options) + transform_pb2_grpc.add_SourceTransformServicer_to_server(udfs, server) + listen_addr = "unix:///tmp/async_st_metadata.sock" + server.add_insecure_port(listen_addr) + logging.info("Starting metadata server on %s", listen_addr) + global _metadata_s + _metadata_s = server + await server.start() + await server.wait_for_termination() + + +@patch("psutil.Process.kill", mock_terminate_on_stop) +class TestAsyncTransformerMetadata(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + global _metadata_loop + loop = asyncio.new_event_loop() + _metadata_loop = loop + _thread = threading.Thread(target=metadata_startup_callable, args=(loop,), daemon=True) + _thread.start() + udfs = new_metadata_async_st() + asyncio.run_coroutine_threadsafe(start_metadata_server(udfs), loop=loop) + while True: + try: + with grpc.insecure_channel("unix:///tmp/async_st_metadata.sock") as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") + LOGGER.error(e) + + @classmethod + def tearDownClass(cls) -> None: + try: + _metadata_loop.stop() + LOGGER.info("stopped the metadata event loop") + except Exception as e: + LOGGER.error(e) + + def test_source_transformer_with_metadata(self) -> None: + stub = transform_pb2_grpc.SourceTransformStub(_metadata_channel) + request = get_test_datums(with_metadata=True) + generator_response = None + try: + generator_response = stub.SourceTransformFn(request_iterator=request_generator(request)) + except grpc.RpcError as e: + logging.error(e) + raise + + responses = [] + for r in generator_response: + responses.append(r) + + # 1 handshake + 3 data responses + self.assertEqual(4, len(responses)) + self.assertTrue(responses[0].handshake.sot) + + # Verify metadata is passed through correctly + for idx, resp in enumerate(responses[1:], 1): + _id = "test-id-" + str(idx) + self.assertEqual(_id, resp.id) + self.assertEqual(1, len(resp.results)) + # Verify user metadata is returned + self.assertEqual( + resp.results[0].metadata.user_metadata["custom_info"], + metadata_pb2.KeyValueGroup(key_value={"version": f"{idx}.0.0".encode()}), + ) + # System metadata should be empty in responses (user cannot set it) + self.assertEqual(resp.results[0].metadata.sys_metadata, {}) + + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main() diff --git a/packages/pynumaflow/tests/sourcetransform/test_messages.py b/packages/pynumaflow/tests/sourcetransform/test_messages.py index 9f6baceb..eb8124c5 100644 --- a/packages/pynumaflow/tests/sourcetransform/test_messages.py +++ b/packages/pynumaflow/tests/sourcetransform/test_messages.py @@ -1,7 +1,15 @@ import unittest from datetime import datetime, timezone -from pynumaflow.sourcetransformer import Messages, Message, DROP, SourceTransformer, Datum +from pynumaflow.sourcetransformer import ( + Messages, + Message, + DROP, + SourceTransformer, + Datum, + UserMetadata, + SystemMetadata, +) from tests.testing_utils import mock_new_event_time @@ -45,6 +53,28 @@ def test_message_to_drop(self): self.assertEqual(mock_obj["Tags"], msgt.tags) self.assertEqual(mock_obj["EventTime"], msgt.event_time) + def test_message_with_user_metadata(self): + user_meta = UserMetadata() + user_meta.add_key("group1", "key1", b"value1") + user_meta.add_key("group1", "key2", b"value2") + + msgt = Message( + mock_message_t(), + mock_event_time(), + keys=["test_key"], + user_metadata=user_meta, + ) + self.assertEqual(mock_message_t(), msgt.value) + self.assertEqual(["test_key"], msgt.keys) + self.assertEqual(b"value1", msgt.user_metadata.value("group1", "key1")) + self.assertEqual(b"value2", msgt.user_metadata.value("group1", "key2")) + self.assertEqual(["group1"], msgt.user_metadata.groups()) + + def test_message_default_user_metadata(self): + msgt = Message(mock_message_t(), mock_event_time()) + self.assertIsNotNone(msgt.user_metadata) + self.assertEqual(0, len(msgt.user_metadata)) + class TestMessages(unittest.TestCase): @staticmethod @@ -94,6 +124,42 @@ def test_err(self): msgts[:1] +class TestDatum(unittest.TestCase): + def test_datum_with_metadata(self): + user_meta = UserMetadata() + user_meta.add_key("group1", "key1", b"value1") + + sys_meta = SystemMetadata({"sys_group": {"sys_key": b"sys_value"}}) + + d = Datum( + keys=["test_key"], + value=mock_message_t(), + event_time=mock_event_time(), + watermark=mock_event_time(), + headers={"header1": "value1"}, + user_metadata=user_meta, + system_metadata=sys_meta, + ) + self.assertEqual(["test_key"], d.keys) + self.assertEqual(mock_message_t(), d.value) + self.assertEqual(mock_event_time(), d.event_time) + self.assertEqual({"header1": "value1"}, d.headers) + self.assertEqual(b"value1", d.user_metadata.value("group1", "key1")) + self.assertEqual(b"sys_value", d.system_metadata.value("sys_group", "sys_key")) + + def test_datum_default_metadata(self): + d = Datum( + keys=["test_key"], + value=mock_message_t(), + event_time=mock_event_time(), + watermark=mock_event_time(), + ) + self.assertIsNotNone(d.user_metadata) + self.assertIsNotNone(d.system_metadata) + self.assertEqual(0, len(d.user_metadata)) + self.assertEqual([], d.system_metadata.groups()) + + class ExampleSourceTransformClass(SourceTransformer): def handler(self, keys: list[str], datum: Datum) -> Messages: messages = Messages() diff --git a/packages/pynumaflow/tests/sourcetransform/test_sync_server.py b/packages/pynumaflow/tests/sourcetransform/test_sync_server.py index 3ffcd0da..2576c97f 100644 --- a/packages/pynumaflow/tests/sourcetransform/test_sync_server.py +++ b/packages/pynumaflow/tests/sourcetransform/test_sync_server.py @@ -7,8 +7,9 @@ from grpc import StatusCode from grpc_testing import server_from_dictionary, strict_real_time +from pynumaflow.proto.common import metadata_pb2 from pynumaflow.proto.sourcetransformer import transform_pb2 -from pynumaflow.sourcetransformer import SourceTransformServer +from pynumaflow.sourcetransformer import SourceTransformServer, Datum, Messages, Message from tests.sourcetransform.utils import transform_handler, err_transform_handler, get_test_datums from tests.testing_utils import ( mock_terminate_on_stop, @@ -195,5 +196,81 @@ def test_max_threads(self): self.assertEqual(server.max_threads, 4) +def metadata_transform_handler(keys: list[str], datum: Datum) -> Messages: + """Handler that validates system metadata and passes through user metadata.""" + if datum.system_metadata.value("numaflow_version_info", "version") != b"1.0.0": + raise ValueError("System metadata version mismatch") + + val = datum.value + msg = "payload:{} event_time:{} ".format( + val.decode("utf-8"), + datum.event_time, + ) + val = bytes(msg, encoding="utf-8") + messages = Messages() + messages.append( + Message(val, mock_new_event_time(), keys=keys, user_metadata=datum.user_metadata) + ) + return messages + + +@patch("psutil.Process.kill", mock_terminate_on_stop) +class TestServerMetadata(unittest.TestCase): + def setUp(self) -> None: + server = SourceTransformServer(source_transform_instance=metadata_transform_handler) + my_servicer = server.servicer + services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: my_servicer} + self.test_server = server_from_dictionary(services, strict_real_time()) + + def test_source_transform_with_metadata(self): + test_datums = get_test_datums(with_metadata=True) + + method = self.test_server.invoke_stream_stream( + method_descriptor=( + transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ + "SourceTransformFn" + ] + ), + invocation_metadata={}, + timeout=1, + ) + + for x in test_datums: + method.send_request(x) + method.requests_closed() + + responses = [] + while True: + try: + resp = method.take_response() + responses.append(resp) + except ValueError as err: + if "No more responses!" in err.__str__(): + break + + metadata, code, details = method.termination() + + # 1 handshake + 3 data responses + self.assertEqual(4, len(responses)) + self.assertTrue(responses[0].handshake.sot) + + # Verify metadata is passed through correctly + result_metadata = {} + for resp in responses[1:]: + result_metadata[resp.id] = resp.results[0].metadata + + for idx in range(1, 4): + _id = f"test-id-{idx}" + self.assertIn(_id, result_metadata) + self.assertEqual( + result_metadata[_id].user_metadata["custom_info"], + metadata_pb2.KeyValueGroup(key_value={"version": f"{idx}.0.0".encode()}), + ) + # System metadata should be empty in responses + self.assertEqual(result_metadata[_id].sys_metadata, {}) + + self.assertEqual(code, StatusCode.OK) + + if __name__ == "__main__": unittest.main() diff --git a/packages/pynumaflow/tests/sourcetransform/utils.py b/packages/pynumaflow/tests/sourcetransform/utils.py index 03e5d861..cec0b895 100644 --- a/packages/pynumaflow/tests/sourcetransform/utils.py +++ b/packages/pynumaflow/tests/sourcetransform/utils.py @@ -1,3 +1,4 @@ +from pynumaflow.proto.common import metadata_pb2 from pynumaflow.proto.sourcetransformer import transform_pb2 from pynumaflow.sourcetransformer import Datum, Messages, Message from tests.testing_utils import mock_new_event_time, mock_message, get_time_args @@ -19,7 +20,7 @@ def err_transform_handler(_: list[str], __: Datum) -> Messages: raise RuntimeError("Something is fishy!") -def get_test_datums(handshake=True): +def get_test_datums(handshake=True, with_metadata=False): event_time_timestamp, watermark_timestamp = get_time_args() responses = [] @@ -31,35 +32,117 @@ def get_test_datums(handshake=True): ) ) - test_datum = [ - transform_pb2.SourceTransformRequest( - request=transform_pb2.SourceTransformRequest.Request( - keys=["test"], - value=mock_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, - id="test-id-1", - ) - ), - transform_pb2.SourceTransformRequest( - request=transform_pb2.SourceTransformRequest.Request( - keys=["test"], - value=mock_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, - id="test-id-2", - ) - ), - transform_pb2.SourceTransformRequest( - request=transform_pb2.SourceTransformRequest.Request( - keys=["test"], - value=mock_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, - id="test-id-3", - ) - ), - ] + if with_metadata: + test_datum = [ + transform_pb2.SourceTransformRequest( + request=transform_pb2.SourceTransformRequest.Request( + keys=["test"], + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + id="test-id-1", + metadata=metadata_pb2.Metadata( + previous_vertex="test-source", + sys_metadata={ + "numaflow_version_info": metadata_pb2.KeyValueGroup( + key_value={ + "version": b"1.0.0", + } + ), + }, + user_metadata={ + "custom_info": metadata_pb2.KeyValueGroup( + key_value={ + "version": b"1.0.0", + } + ), + }, + ), + ) + ), + transform_pb2.SourceTransformRequest( + request=transform_pb2.SourceTransformRequest.Request( + keys=["test"], + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + id="test-id-2", + metadata=metadata_pb2.Metadata( + previous_vertex="test-source", + sys_metadata={ + "numaflow_version_info": metadata_pb2.KeyValueGroup( + key_value={ + "version": b"1.0.0", + } + ), + }, + user_metadata={ + "custom_info": metadata_pb2.KeyValueGroup( + key_value={ + "version": b"2.0.0", + } + ), + }, + ), + ) + ), + transform_pb2.SourceTransformRequest( + request=transform_pb2.SourceTransformRequest.Request( + keys=["test"], + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + id="test-id-3", + metadata=metadata_pb2.Metadata( + previous_vertex="test-source", + sys_metadata={ + "numaflow_version_info": metadata_pb2.KeyValueGroup( + key_value={ + "version": b"1.0.0", + } + ), + }, + user_metadata={ + "custom_info": metadata_pb2.KeyValueGroup( + key_value={ + "version": b"3.0.0", + } + ), + }, + ), + ) + ), + ] + else: + test_datum = [ + transform_pb2.SourceTransformRequest( + request=transform_pb2.SourceTransformRequest.Request( + keys=["test"], + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + id="test-id-1", + ) + ), + transform_pb2.SourceTransformRequest( + request=transform_pb2.SourceTransformRequest.Request( + keys=["test"], + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + id="test-id-2", + ) + ), + transform_pb2.SourceTransformRequest( + request=transform_pb2.SourceTransformRequest.Request( + keys=["test"], + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + id="test-id-3", + ) + ), + ] for x in test_datum: responses.append(x) return responses