Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pynumaflow.sourcetransformer.multiproc_server import SourceTransformMultiProcServer
from pynumaflow.sourcetransformer.server import SourceTransformServer
from pynumaflow.sourcetransformer.async_server import SourceTransformAsyncServer
from pynumaflow._metadata import UserMetadata, SystemMetadata

__all__ = [
"Message",
Expand All @@ -18,4 +19,6 @@
"SourceTransformer",
"SourceTransformMultiProcServer",
"SourceTransformAsyncServer",
"UserMetadata",
"SystemMetadata",
]
45 changes: 42 additions & 3 deletions packages/pynumaflow/pynumaflow/sourcetransformer/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from warnings import warn

from pynumaflow._constants import DROP
from pynumaflow._metadata import UserMetadata, SystemMetadata

M = TypeVar("M", bound="Message")
Ms = TypeVar("Ms", bound="Messages")
Expand All @@ -22,17 +23,24 @@ class Message:
event_time: event time of the message, usually extracted from the payload.
keys: []string keys for vertex (optional)
tags: []string tags for conditional forwarding (optional)
user_metadata: metadata for the message (optional)
"""

__slots__ = ("_value", "_keys", "_tags", "_event_time")
__slots__ = ("_value", "_keys", "_tags", "_event_time", "_user_metadata")

_keys: list[str]
_tags: list[str]
_value: bytes
_event_time: datetime
_user_metadata: UserMetadata

def __init__(
self, value: bytes, event_time: datetime, keys: list[str] = None, tags: list[str] = None
self,
value: bytes,
event_time: datetime,
keys: list[str] = None,
tags: list[str] = None,
user_metadata: Optional[UserMetadata] = None,
):
"""
Creates a Message object to send value to a vertex.
Expand All @@ -43,6 +51,7 @@ def __init__(
# There is no year 0, so setting following as default event time.
self._event_time = event_time or datetime(1, 1, 1, 0, 0)
self._value = value or b""
self._user_metadata = user_metadata or UserMetadata()

@classmethod
def to_drop(cls: type[M], event_time: datetime) -> M:
Expand All @@ -64,6 +73,10 @@ def value(self) -> bytes:
def tags(self) -> list[str]:
return self._tags

@property
def user_metadata(self) -> UserMetadata:
return self._user_metadata


class Messages(Sequence[M]):
"""
Expand Down Expand Up @@ -119,6 +132,8 @@ class Datum:
event_time: the event time of the event.
watermark: the watermark of the event.
headers: the headers of the event.
user_metadata: the user metadata of the event.
system_metadata: the system metadata of the event.

Example:
```py
Expand All @@ -135,13 +150,23 @@ class Datum:
```
"""

__slots__ = ("_keys", "_value", "_event_time", "_watermark", "_headers")
__slots__ = (
"_keys",
"_value",
"_event_time",
"_watermark",
"_headers",
"_user_metadata",
"_system_metadata",
)

_keys: list[str]
_value: bytes
_event_time: datetime
_watermark: datetime
_headers: dict[str, str]
_user_metadata: UserMetadata
_system_metadata: SystemMetadata

def __init__(
self,
Expand All @@ -150,6 +175,8 @@ def __init__(
event_time: datetime,
watermark: datetime,
headers: Optional[dict[str, str]] = None,
user_metadata: Optional[UserMetadata] = None,
system_metadata: Optional[SystemMetadata] = None,
):
self._keys = keys or list()
self._value = value or b""
Expand All @@ -160,6 +187,8 @@ def __init__(
raise TypeError(f"Wrong data type: {type(watermark)} for Datum.watermark")
self._watermark = watermark
self._headers = headers or {}
self._user_metadata = user_metadata or UserMetadata()
self._system_metadata = system_metadata or SystemMetadata()

@property
def keys(self) -> list[str]:
Expand All @@ -186,6 +215,16 @@ def headers(self) -> dict[str, str]:
"""Returns the headers of the event."""
return self._headers.copy()

@property
def user_metadata(self) -> UserMetadata:
"""Returns the user metadata of the event."""
return self._user_metadata

@property
def system_metadata(self) -> SystemMetadata:
"""Returns the system metadata of the event."""
return self._system_metadata


class SourceTransformer(metaclass=ABCMeta):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from google.protobuf import timestamp_pb2 as _timestamp_pb2

from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING
from pynumaflow._metadata import _user_and_system_metadata_from_proto
from pynumaflow.proto.sourcetransformer import transform_pb2, transform_pb2_grpc
from pynumaflow.shared.asynciter import NonBlockingIterator
from pynumaflow.shared.server import handle_async_error
Expand Down Expand Up @@ -105,12 +106,17 @@ async def _invoke_transform(
Invokes the user defined function.
"""
try:
user_metadata, system_metadata = _user_and_system_metadata_from_proto(
request.request.metadata
)
datum = Datum(
keys=list(request.request.keys),
value=request.request.value,
event_time=request.request.event_time.ToDatetime(),
watermark=request.request.watermark.ToDatetime(),
headers=dict(request.request.headers),
user_metadata=user_metadata,
system_metadata=system_metadata,
)
msgs = await self.__transform_handler(list(request.request.keys), datum)
results = []
Expand All @@ -123,6 +129,7 @@ async def _invoke_transform(
value=msg.value,
tags=msg.tags,
event_time=event_time_timestamp,
metadata=msg.user_metadata._to_proto(),
)
)
await result_queue.put(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pynumaflow.shared.synciter import SyncIterator
from pynumaflow.sourcetransformer import Datum
from pynumaflow.sourcetransformer._dtypes import SourceTransformCallable
from pynumaflow._metadata import _user_and_system_metadata_from_proto
from pynumaflow.proto.sourcetransformer import transform_pb2
from pynumaflow.proto.sourcetransformer import transform_pb2_grpc
from pynumaflow.types import NumaflowServicerContext
Expand Down Expand Up @@ -119,12 +120,17 @@ def _invoke_transformer(
self, context, request: transform_pb2.SourceTransformRequest, result_queue: SyncIterator
):
try:
user_metadata, system_metadata = _user_and_system_metadata_from_proto(
request.request.metadata
)
d = Datum(
keys=list(request.request.keys),
value=request.request.value,
event_time=request.request.event_time.ToDatetime(),
watermark=request.request.watermark.ToDatetime(),
headers=dict(request.request.headers),
user_metadata=user_metadata,
system_metadata=system_metadata,
)
responses = self.__transform_handler(list(request.request.keys), d)

Expand All @@ -138,6 +144,7 @@ def _invoke_transformer(
value=resp.value,
tags=resp.tags,
event_time=event_time_timestamp,
metadata=resp.user_metadata._to_proto(),
)
)
result_queue.put(
Expand Down
117 changes: 117 additions & 0 deletions packages/pynumaflow/tests/sourcetransform/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from pynumaflow import setup_logging
from pynumaflow._constants import MAX_MESSAGE_SIZE
from pynumaflow.proto.common import metadata_pb2
from pynumaflow.proto.sourcetransformer import transform_pb2_grpc
from pynumaflow.sourcetransformer import Datum, Messages, Message, SourceTransformer
from pynumaflow.sourcetransformer.async_server import SourceTransformAsyncServer
Expand Down Expand Up @@ -267,6 +268,122 @@ def test_max_threads(self):
self.assertEqual(server.max_threads, 4)


class MetadataAsyncSourceTransformer(SourceTransformer):
"""Source transformer that validates and passes through metadata."""

async def handler(self, keys: list[str], datum: Datum) -> Messages:
# Validate system metadata
if datum.system_metadata.value("numaflow_version_info", "version") != b"1.0.0":
raise ValueError("System metadata version mismatch")

val = datum.value
msg = "payload:{} event_time:{} ".format(
val.decode("utf-8"),
datum.event_time,
)
val = bytes(msg, encoding="utf-8")
messages = Messages()
# Pass user metadata to the output message
messages.append(
Message(val, mock_new_event_time(), keys=keys, user_metadata=datum.user_metadata)
)
return messages


_metadata_s: Server = None
_metadata_channel = grpc.insecure_channel("unix:///tmp/async_st_metadata.sock")
_metadata_loop = None


def metadata_startup_callable(loop):
asyncio.set_event_loop(loop)
loop.run_forever()


def new_metadata_async_st():
handle = MetadataAsyncSourceTransformer()
server = SourceTransformAsyncServer(source_transform_instance=handle)
return server.servicer


async def start_metadata_server(udfs):
_server_options = [
("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
]
server = grpc.aio.server(options=_server_options)
transform_pb2_grpc.add_SourceTransformServicer_to_server(udfs, server)
listen_addr = "unix:///tmp/async_st_metadata.sock"
server.add_insecure_port(listen_addr)
logging.info("Starting metadata server on %s", listen_addr)
global _metadata_s
_metadata_s = server
await server.start()
await server.wait_for_termination()


@patch("psutil.Process.kill", mock_terminate_on_stop)
class TestAsyncTransformerMetadata(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
global _metadata_loop
loop = asyncio.new_event_loop()
_metadata_loop = loop
_thread = threading.Thread(target=metadata_startup_callable, args=(loop,), daemon=True)
_thread.start()
udfs = new_metadata_async_st()
asyncio.run_coroutine_threadsafe(start_metadata_server(udfs), loop=loop)
while True:
try:
with grpc.insecure_channel("unix:///tmp/async_st_metadata.sock") as channel:
f = grpc.channel_ready_future(channel)
f.result(timeout=10)
if f.done():
break
except grpc.FutureTimeoutError as e:
LOGGER.error("error trying to connect to grpc server")
LOGGER.error(e)

@classmethod
def tearDownClass(cls) -> None:
try:
_metadata_loop.stop()
LOGGER.info("stopped the metadata event loop")
except Exception as e:
LOGGER.error(e)

def test_source_transformer_with_metadata(self) -> None:
stub = transform_pb2_grpc.SourceTransformStub(_metadata_channel)
request = get_test_datums(with_metadata=True)
generator_response = None
try:
generator_response = stub.SourceTransformFn(request_iterator=request_generator(request))
except grpc.RpcError as e:
logging.error(e)
raise

responses = []
for r in generator_response:
responses.append(r)

# 1 handshake + 3 data responses
self.assertEqual(4, len(responses))
self.assertTrue(responses[0].handshake.sot)

# Verify metadata is passed through correctly
for idx, resp in enumerate(responses[1:], 1):
_id = "test-id-" + str(idx)
self.assertEqual(_id, resp.id)
self.assertEqual(1, len(resp.results))
# Verify user metadata is returned
self.assertEqual(
resp.results[0].metadata.user_metadata["custom_info"],
metadata_pb2.KeyValueGroup(key_value={"version": f"{idx}.0.0".encode()}),
)
# System metadata should be empty in responses (user cannot set it)
self.assertEqual(resp.results[0].metadata.sys_metadata, {})


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
Loading