6 changes: 6 additions & 0 deletions google/cloud/storage/_experimental/asyncio/_utils.py
@@ -33,3 +33,9 @@ def raise_if_no_fast_crc32c():
"C extension is required for faster data integrity checks."
"For more information, see https://github.com/googleapis/python-crc32c."
)


def update_write_handle_if_exists(obj, response):
"""Update the write_handle attribute of an object if it exists in the response."""
if hasattr(response, "write_handle") and response.write_handle is not None:
obj.write_handle = response.write_handle
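
A minimal sketch of how the new helper behaves, using SimpleNamespace stand-ins rather than real BidiWriteObjectResponse protos (the stand-ins are illustrative only; the import assumes the package is installed with this change):

from types import SimpleNamespace

from google.cloud.storage._experimental.asyncio._utils import (
    update_write_handle_if_exists,
)

writer = SimpleNamespace(write_handle=b"stale-handle")

# A response carrying a refreshed handle replaces the cached value.
update_write_handle_if_exists(writer, SimpleNamespace(write_handle=b"fresh-handle"))
assert writer.write_handle == b"fresh-handle"

# A response whose write_handle is missing or None leaves the cached value untouched.
update_write_handle_if_exists(writer, SimpleNamespace(write_handle=None))
assert writer.write_handle == b"fresh-handle"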
@@ -27,7 +27,7 @@
from google_crc32c import Checksum
from google.api_core import exceptions

from ._utils import raise_if_no_fast_crc32c
from . import _utils
from google.cloud import _storage_v2
from google.cloud.storage._experimental.asyncio.async_grpc_client import (
AsyncGrpcClient,
@@ -121,7 +121,7 @@ def __init__(
servers. Default is `_DEFAULT_FLUSH_INTERVAL_BYTES`.
Must be a multiple of `_MAX_CHUNK_SIZE_BYTES`.
"""
raise_if_no_fast_crc32c()
_utils.raise_if_no_fast_crc32c()
self.client = client
self.bucket_name = bucket_name
self.object_name = object_name
@@ -175,6 +175,7 @@ async def state_lookup(self) -> int:
)
)
response = await self.write_obj_stream.recv()
_utils.update_write_handle_if_exists(self, response)
self.persisted_size = response.persisted_size
return self.persisted_size

@@ -253,6 +254,7 @@ async def append(self, data: bytes) -> None:

if is_last_chunk:
response = await self.write_obj_stream.recv()
_utils.update_write_handle_if_exists(self, response)
self.persisted_size = response.persisted_size
self.offset = self.persisted_size
self.bytes_appended_since_last_flush = 0
@@ -295,6 +297,7 @@ async def flush(self) -> int:
)
)
response = await self.write_obj_stream.recv()
_utils.update_write_handle_if_exists(self, response)
self.persisted_size = response.persisted_size
self.offset = self.persisted_size
return self.persisted_size
@@ -351,6 +354,7 @@ async def finalize(self) -> _storage_v2.Object:
_storage_v2.BidiWriteObjectRequest(finish_write=True)
)
response = await self.write_obj_stream.recv()
_utils.update_write_handle_if_exists(self, response)
self.object_resource = response.resource
self.persisted_size = self.object_resource.size
await self.write_obj_stream.close()
@@ -22,6 +22,7 @@

"""
from typing import Optional
from . import _utils
from google.cloud import _storage_v2
from google.cloud.storage._experimental.asyncio.async_grpc_client import AsyncGrpcClient
from google.cloud.storage._experimental.asyncio.async_abstract_object_stream import (
@@ -190,7 +191,7 @@ async def requests_done(self):
"""Signals that all requests have been sent."""

await self.socket_like_rpc.send(None)
await self.socket_like_rpc.recv()
_utils.update_write_handle_if_exists(self, await self.socket_like_rpc.recv())

async def send(
self, bidi_write_object_request: _storage_v2.BidiWriteObjectRequest
@@ -218,7 +219,9 @@ async def recv(self) -> _storage_v2.BidiWriteObjectResponse:
"""
if not self._is_stream_open:
raise ValueError("Stream is not open")
return await self.socket_like_rpc.recv()
response = await self.socket_like_rpc.recv()
_utils.update_write_handle_if_exists(self, response)
return response

@property
def is_stream_open(self) -> bool:
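Taken together, the writer and stream changes above apply a single pattern: every code path that receives a BidiWriteObjectResponse routes it through _utils.update_write_handle_if_exists, so the cached write_handle never goes stale when GCS refreshes the handle mid-stream. Below is a minimal, hypothetical sketch of that pattern (HandleCachingStream is illustrative only, not the library's _AsyncWriteObjectStream), using unittest.mock so it runs without a real GCS connection:

import asyncio
from unittest import mock

from google.cloud.storage._experimental.asyncio._utils import (
    update_write_handle_if_exists,
)


class HandleCachingStream:
    """Hypothetical stand-in, not the library's _AsyncWriteObjectStream."""

    def __init__(self, rpc):
        # ``rpc`` is assumed to expose an async recv() returning a response object.
        self._rpc = rpc
        self.write_handle = None

    async def recv(self):
        response = await self._rpc.recv()
        # Every receive path funnels the response through the shared helper,
        # mirroring recv(), requests_done(), state_lookup(), append(), flush(),
        # and finalize() in the diff above.
        update_write_handle_if_exists(self, response)
        return response


async def demo():
    rpc = mock.AsyncMock()
    rpc.recv.return_value = mock.Mock(write_handle=b"rotated-handle")
    stream = HandleCachingStream(rpc)
    await stream.recv()
    assert stream.write_handle == b"rotated-handle"


asyncio.run(demo())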
102 changes: 99 additions & 3 deletions tests/unit/asyncio/test_async_write_object_stream.py
@@ -24,8 +24,10 @@
BUCKET = "my-bucket"
OBJECT = "my-object"
GENERATION = 12345
WRITE_HANDLE = b"test-handle"
WRITE_HANDLE_PROTO = _storage_v2.BidiWriteHandle(handle=WRITE_HANDLE)
WRITE_HANDLE_BYTES = b"test-handle"
NEW_WRITE_HANDLE_BYTES = b"new-test-handle"
WRITE_HANDLE_PROTO = _storage_v2.BidiWriteHandle(handle=WRITE_HANDLE_BYTES)
NEW_WRITE_HANDLE_PROTO = _storage_v2.BidiWriteHandle(handle=NEW_WRITE_HANDLE_BYTES)


@pytest.fixture
@@ -151,7 +153,9 @@ async def test_open_for_new_object(mock_async_bidi_rpc, mock_client):
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
)
async def test_open_for_new_object_with_generation_zero(mock_async_bidi_rpc, mock_client):
async def test_open_for_new_object_with_generation_zero(
mock_async_bidi_rpc, mock_client
):
"""Test opening a stream for a new object."""
# Arrange
socket_like_rpc = mock.AsyncMock()
@@ -487,3 +491,95 @@ async def test_requests_done(mock_cls_async_bidi_rpc, mock_client):
# Assert
write_obj_stream.socket_like_rpc.send.assert_called_once_with(None)
write_obj_stream.socket_like_rpc.recv.assert_called_once()


@pytest.mark.asyncio
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
)
async def test_open_for_existing_object_with_none_size(
mock_async_bidi_rpc, mock_client
):
"""Test opening a stream for an existing object where size is None."""
# Arrange
socket_like_rpc = mock.AsyncMock()
mock_async_bidi_rpc.return_value = socket_like_rpc
socket_like_rpc.open = mock.AsyncMock()

mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse)
mock_response.resource = mock.MagicMock(spec=_storage_v2.Object)
mock_response.resource.size = None
mock_response.resource.generation = GENERATION
mock_response.write_handle = WRITE_HANDLE_PROTO
socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response)

stream = _AsyncWriteObjectStream(
mock_client, BUCKET, OBJECT, generation_number=GENERATION
)

# Act
await stream.open()

# Assert
assert stream.persisted_size == 0


@pytest.mark.asyncio
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
)
async def test_recv_updates_write_handle(mock_cls_async_bidi_rpc, mock_client):
"""Test that recv updates the write_handle if present in the response."""
# Arrange
write_obj_stream = await instantiate_write_obj_stream(
mock_client, mock_cls_async_bidi_rpc, open=True
)

assert write_obj_stream.write_handle == WRITE_HANDLE_PROTO # Initial handle

# GCS can periodically update the write handle in its responses.
bidi_write_object_response = _storage_v2.BidiWriteObjectResponse(
write_handle=NEW_WRITE_HANDLE_PROTO
)
write_obj_stream.socket_like_rpc.recv = AsyncMock(
return_value=bidi_write_object_response
)

# Act
response = await write_obj_stream.recv()

# Assert
write_obj_stream.socket_like_rpc.recv.assert_called_once()
assert response == bidi_write_object_response
# The stream's write_handle should now reflect the new handle.
assert write_obj_stream.write_handle == NEW_WRITE_HANDLE_PROTO


@pytest.mark.asyncio
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
)
async def test_requests_done_updates_write_handle(mock_cls_async_bidi_rpc, mock_client):
"""Test that requests_done updates the write_handle if present in the response."""
# Arrange
write_obj_stream = await instantiate_write_obj_stream(
mock_client, mock_cls_async_bidi_rpc, open=True
)
assert write_obj_stream.write_handle == WRITE_HANDLE_PROTO # Initial handle

bidi_write_object_response = _storage_v2.BidiWriteObjectResponse(
write_handle=NEW_WRITE_HANDLE_PROTO
)
write_obj_stream.socket_like_rpc.send = AsyncMock()
write_obj_stream.socket_like_rpc.recv = AsyncMock(
return_value=bidi_write_object_response
)

# Act
await write_obj_stream.requests_done()

# Assert
write_obj_stream.socket_like_rpc.send.assert_called_once_with(None)
write_obj_stream.socket_like_rpc.recv.assert_called_once()
assert write_obj_stream.write_handle == NEW_WRITE_HANDLE_PROTO
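
The added tests can be run on their own with, for example, pytest tests/unit/asyncio/test_async_write_object_stream.py -k write_handle, assuming pytest and pytest-asyncio are installed to drive the @pytest.mark.asyncio cases.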