diff --git a/google/cloud/storage/_experimental/asyncio/_utils.py b/google/cloud/storage/_experimental/asyncio/_utils.py index 32d83a586..170a0cfae 100644 --- a/google/cloud/storage/_experimental/asyncio/_utils.py +++ b/google/cloud/storage/_experimental/asyncio/_utils.py @@ -33,3 +33,9 @@ def raise_if_no_fast_crc32c(): "C extension is required for faster data integrity checks." "For more information, see https://github.com/googleapis/python-crc32c." ) + + +def update_write_handle_if_exists(obj, response): + """Update the write_handle attribute of an object if it exists in the response.""" + if hasattr(response, "write_handle") and response.write_handle is not None: + obj.write_handle = response.write_handle diff --git a/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py b/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py index c961fbefb..eda21019d 100644 --- a/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py +++ b/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py @@ -27,7 +27,7 @@ from google_crc32c import Checksum from google.api_core import exceptions -from ._utils import raise_if_no_fast_crc32c +from . import _utils from google.cloud import _storage_v2 from google.cloud.storage._experimental.asyncio.async_grpc_client import ( AsyncGrpcClient, @@ -121,7 +121,7 @@ def __init__( servers. Default is `_DEFAULT_FLUSH_INTERVAL_BYTES`. Must be a multiple of `_MAX_CHUNK_SIZE_BYTES`. 
""" - raise_if_no_fast_crc32c() + _utils.raise_if_no_fast_crc32c() self.client = client self.bucket_name = bucket_name self.object_name = object_name @@ -175,6 +175,7 @@ async def state_lookup(self) -> int: ) ) response = await self.write_obj_stream.recv() + _utils.update_write_handle_if_exists(self, response) self.persisted_size = response.persisted_size return self.persisted_size @@ -253,6 +254,7 @@ async def append(self, data: bytes) -> None: if is_last_chunk: response = await self.write_obj_stream.recv() + _utils.update_write_handle_if_exists(self, response) self.persisted_size = response.persisted_size self.offset = self.persisted_size self.bytes_appended_since_last_flush = 0 @@ -295,6 +297,7 @@ async def flush(self) -> int: ) ) response = await self.write_obj_stream.recv() + _utils.update_write_handle_if_exists(self, response) self.persisted_size = response.persisted_size self.offset = self.persisted_size return self.persisted_size @@ -351,6 +354,7 @@ async def finalize(self) -> _storage_v2.Object: _storage_v2.BidiWriteObjectRequest(finish_write=True) ) response = await self.write_obj_stream.recv() + _utils.update_write_handle_if_exists(self, response) self.object_resource = response.resource self.persisted_size = self.object_resource.size await self.write_obj_stream.close() diff --git a/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py b/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py index 731b18e45..682438dea 100644 --- a/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py +++ b/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py @@ -22,6 +22,7 @@ """ from typing import Optional +from . 
import _utils from google.cloud import _storage_v2 from google.cloud.storage._experimental.asyncio.async_grpc_client import AsyncGrpcClient from google.cloud.storage._experimental.asyncio.async_abstract_object_stream import ( @@ -190,7 +191,7 @@ async def requests_done(self): """Signals that all requests have been sent.""" await self.socket_like_rpc.send(None) - await self.socket_like_rpc.recv() + _utils.update_write_handle_if_exists(self, await self.socket_like_rpc.recv()) async def send( self, bidi_write_object_request: _storage_v2.BidiWriteObjectRequest @@ -218,7 +219,9 @@ async def recv(self) -> _storage_v2.BidiWriteObjectResponse: """ if not self._is_stream_open: raise ValueError("Stream is not open") - return await self.socket_like_rpc.recv() + response = await self.socket_like_rpc.recv() + _utils.update_write_handle_if_exists(self, response) + return response @property def is_stream_open(self) -> bool: diff --git a/tests/unit/asyncio/test_async_write_object_stream.py b/tests/unit/asyncio/test_async_write_object_stream.py index 619d5f7e6..92ba2925a 100644 --- a/tests/unit/asyncio/test_async_write_object_stream.py +++ b/tests/unit/asyncio/test_async_write_object_stream.py @@ -24,8 +24,10 @@ BUCKET = "my-bucket" OBJECT = "my-object" GENERATION = 12345 -WRITE_HANDLE = b"test-handle" -WRITE_HANDLE_PROTO = _storage_v2.BidiWriteHandle(handle=WRITE_HANDLE) +WRITE_HANDLE_BYTES = b"test-handle" +NEW_WRITE_HANDLE_BYTES = b"new-test-handle" +WRITE_HANDLE_PROTO = _storage_v2.BidiWriteHandle(handle=WRITE_HANDLE_BYTES) +NEW_WRITE_HANDLE_PROTO = _storage_v2.BidiWriteHandle(handle=NEW_WRITE_HANDLE_BYTES) @pytest.fixture @@ -151,7 +153,9 @@ async def test_open_for_new_object(mock_async_bidi_rpc, mock_client): @mock.patch( "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" ) -async def test_open_for_new_object_with_generation_zero(mock_async_bidi_rpc, mock_client): +async def test_open_for_new_object_with_generation_zero( + mock_async_bidi_rpc, 
mock_client +): """Test opening a stream for a new object.""" # Arrange socket_like_rpc = mock.AsyncMock() @@ -487,3 +491,95 @@ async def test_requests_done(mock_cls_async_bidi_rpc, mock_client): # Assert write_obj_stream.socket_like_rpc.send.assert_called_once_with(None) write_obj_stream.socket_like_rpc.recv.assert_called_once() + + +@pytest.mark.asyncio +@mock.patch( + "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" +) +async def test_open_for_existing_object_with_none_size( + mock_async_bidi_rpc, mock_client +): + """Test opening a stream for an existing object where size is None.""" + # Arrange + socket_like_rpc = mock.AsyncMock() + mock_async_bidi_rpc.return_value = socket_like_rpc + socket_like_rpc.open = mock.AsyncMock() + + mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse) + mock_response.resource = mock.MagicMock(spec=_storage_v2.Object) + mock_response.resource.size = None + mock_response.resource.generation = GENERATION + mock_response.write_handle = WRITE_HANDLE_PROTO + socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response) + + stream = _AsyncWriteObjectStream( + mock_client, BUCKET, OBJECT, generation_number=GENERATION + ) + + # Act + await stream.open() + + # Assert + assert stream.persisted_size == 0 + + +@pytest.mark.asyncio +@mock.patch( + "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" +) +async def test_recv_updates_write_handle(mock_cls_async_bidi_rpc, mock_client): + """Test that recv updates the write_handle if present in the response.""" + # Arrange + write_obj_stream = await instantiate_write_obj_stream( + mock_client, mock_cls_async_bidi_rpc, open=True + ) + + assert write_obj_stream.write_handle == WRITE_HANDLE_PROTO # Initial handle + + # GCS can periodically update the write handle in its responses. 
+ bidi_write_object_response = _storage_v2.BidiWriteObjectResponse( + write_handle=NEW_WRITE_HANDLE_PROTO + ) + write_obj_stream.socket_like_rpc.recv = AsyncMock( + return_value=bidi_write_object_response + ) + + # Act + response = await write_obj_stream.recv() + + # Assert + write_obj_stream.socket_like_rpc.recv.assert_called_once() + assert response == bidi_write_object_response + # asserts that new write handle has been updated. + assert write_obj_stream.write_handle == NEW_WRITE_HANDLE_PROTO + + +@pytest.mark.asyncio +@mock.patch( + "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" +) +async def test_requests_done_updates_write_handle(mock_cls_async_bidi_rpc, mock_client): + """Test that requests_done updates the write_handle if present in the response.""" + # Arrange + write_obj_stream = await instantiate_write_obj_stream( + mock_client, mock_cls_async_bidi_rpc, open=True + ) + assert write_obj_stream.write_handle == WRITE_HANDLE_PROTO # Initial handle + + # new_write_handle = b"new-test-handle" + bidi_write_object_response = _storage_v2.BidiWriteObjectResponse( + write_handle=NEW_WRITE_HANDLE_PROTO + ) + write_obj_stream.socket_like_rpc.send = AsyncMock() + write_obj_stream.socket_like_rpc.recv = AsyncMock( + return_value=bidi_write_object_response + ) + + # Act + await write_obj_stream.requests_done() + + # Assert + write_obj_stream.socket_like_rpc.send.assert_called_once_with(None) + write_obj_stream.socket_like_rpc.recv.assert_called_once() + assert write_obj_stream.write_handle == NEW_WRITE_HANDLE_PROTO