Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from __future__ import annotations
import asyncio
import google_crc32c
from google.api_core import exceptions
from google_crc32c import Checksum
Expand All @@ -29,6 +30,7 @@
from io import BytesIO
from google.cloud import _storage_v2
from google.cloud.storage.exceptions import DataCorruption
from google.cloud.storage._helpers import generate_random_56_bit_integer


_MAX_READ_RANGES_PER_BIDI_READ_REQUEST = 100
Expand Down Expand Up @@ -78,7 +80,7 @@ class AsyncMultiRangeDownloader:
my_buff2 = BytesIO()
my_buff3 = BytesIO()
my_buff4 = any_object_which_provides_BytesIO_like_interface()
results_arr = await mrd.download_ranges(
await mrd.download_ranges(
[
# (start_byte, bytes_to_read, writeable_buffer)
(0, 100, my_buff1),
Expand All @@ -88,8 +90,8 @@ class AsyncMultiRangeDownloader:
]
)

for result in results_arr:
print("downloaded bytes", result)
# verify data in buffers...
assert my_buff2.getbuffer().nbytes == 20


"""
Expand Down Expand Up @@ -175,6 +177,10 @@ def __init__(
self.read_obj_str: Optional[_AsyncReadObjectStream] = None
self._is_stream_open: bool = False

self._read_id_to_writable_buffer_dict = {}
self._read_id_to_download_ranges_id = {}
self._download_ranges_id_to_pending_read_ids = {}

async def open(self) -> None:
"""Opens the bidi-gRPC connection to read from the object.

Expand Down Expand Up @@ -203,8 +209,8 @@ async def open(self) -> None:
return

async def download_ranges(
self, read_ranges: List[Tuple[int, int, BytesIO]]
) -> List[Result]:
self, read_ranges: List[Tuple[int, int, BytesIO]], lock: asyncio.Lock = None
) -> None:
"""Downloads multiple byte ranges from the object into the buffers
provided by user.

Expand All @@ -214,9 +220,36 @@ async def download_ranges(
to be provided by the user, and user has to make sure appropriate
memory is available in the application to avoid out-of-memory crash.

:rtype: List[:class:`~google.cloud.storage._experimental.asyncio.async_multi_range_downloader.Result`]
:returns: A list of ``Result`` objects, where each object corresponds
to a requested range.
:type lock: asyncio.Lock
:param lock: (Optional) An asyncio lock to synchronize sends and recvs
on the underlying bidi-GRPC stream. This is required when multiple
coroutines are calling this method concurrently.

i.e. Example usage with multiple coroutines:

```
lock = asyncio.Lock()
task1 = asyncio.create_task(mrd.download_ranges(ranges1, lock))
task2 = asyncio.create_task(mrd.download_ranges(ranges2, lock))
await asyncio.gather(task1, task2)

```

If the user wants to call this method serially from multiple coroutines,
then providing a lock is not necessary.

```
await mrd.download_ranges(ranges1)
await mrd.download_ranges(ranges2)

# ... some other code...

```


:raises ValueError: if the underlying bidi-GRPC stream is not open.
:raises ValueError: if the length of read_ranges is more than 1000.
:raises DataCorruption: if a checksum mismatch is detected while reading data.

"""

Expand All @@ -228,32 +261,43 @@ async def download_ranges(
if not self._is_stream_open:
raise ValueError("Underlying bidi-gRPC stream is not open")

read_id_to_writable_buffer_dict = {}
results = []
if lock is None:
lock = asyncio.Lock()

_func_id = generate_random_56_bit_integer()
read_ids_in_current_func = set()
for i in range(0, len(read_ranges), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST):
read_ranges_segment = read_ranges[
i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST
]

read_ranges_for_bidi_req = []
for j, read_range in enumerate(read_ranges_segment):
read_id = i + j
read_id_to_writable_buffer_dict[read_id] = read_range[2]
read_id = generate_random_56_bit_integer()
read_ids_in_current_func.add(read_id)
self._read_id_to_download_ranges_id[read_id] = _func_id
self._read_id_to_writable_buffer_dict[read_id] = read_range[2]
bytes_requested = read_range[1]
results.append(Result(bytes_requested))
read_ranges_for_bidi_req.append(
_storage_v2.ReadRange(
read_offset=read_range[0],
read_length=bytes_requested,
read_id=read_id,
)
)
await self.read_obj_str.send(
_storage_v2.BidiReadObjectRequest(read_ranges=read_ranges_for_bidi_req)
)
async with lock:
await self.read_obj_str.send(
_storage_v2.BidiReadObjectRequest(
read_ranges=read_ranges_for_bidi_req
)
)
self._download_ranges_id_to_pending_read_ids[
_func_id
] = read_ids_in_current_func

while len(read_id_to_writable_buffer_dict) > 0:
response = await self.read_obj_str.recv()
while len(self._download_ranges_id_to_pending_read_ids[_func_id]) > 0:
async with lock:
response = await self.read_obj_str.recv()

if response is None:
raise Exception("None response received, something went wrong.")
Expand All @@ -277,16 +321,15 @@ async def download_ranges(
)

read_id = object_data_range.read_range.read_id
buffer = read_id_to_writable_buffer_dict[read_id]
buffer = self._read_id_to_writable_buffer_dict[read_id]
buffer.write(data)
results[read_id].bytes_written += len(data)

if object_data_range.range_end:
del read_id_to_writable_buffer_dict[
object_data_range.read_range.read_id
]

return results
tmp_dn_ranges_id = self._read_id_to_download_ranges_id[read_id]
self._download_ranges_id_to_pending_read_ids[
tmp_dn_ranges_id
].remove(read_id)
del self._read_id_to_download_ranges_id[read_id]

async def close(self):
"""
Expand Down
18 changes: 18 additions & 0 deletions google/cloud/storage/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from hashlib import md5
import os
import sys
import secrets
from urllib.parse import urlsplit
from urllib.parse import urlunsplit
from uuid import uuid4
Expand Down Expand Up @@ -668,3 +669,20 @@ def _get_default_headers(
"content-type": content_type,
"x-upload-content-type": x_upload_content_type or content_type,
}


def generate_random_56_bit_integer():
    """Return a cryptographically secure random integer that fits in 56 bits.

    The width is deliberately kept below 64 bits: a random 64 bit value can
    exceed 2^63 - 1, the largest positive signed 64 bit integer, which causes
    overflow issues.

    :rtype: int
    :returns: A secure random integer in the range ``[0, 2**56)``.
    """
    # 56 bits expressed as whole bytes (56 / 8 == 7).
    byte_count = 56 // 8
    # Interpret the random bytes as one unsigned big-endian integer.
    return int.from_bytes(secrets.token_bytes(byte_count), "big")
100 changes: 93 additions & 7 deletions tests/unit/asyncio/test_async_multi_range_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import pytest
from unittest import mock
from unittest.mock import AsyncMock
Expand Down Expand Up @@ -107,6 +108,93 @@ async def test_create_mrd(
assert mrd.read_handle == _TEST_READ_HANDLE
assert mrd.is_stream_open

@mock.patch(
"google.cloud.storage._experimental.asyncio.async_multi_range_downloader.generate_random_56_bit_integer"
)
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream"
)
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
)
@pytest.mark.asyncio
async def test_download_ranges_via_async_gather(
    self, mock_grpc_client, mock_cls_async_read_object_stream, mock_random_int
):
    """Run two ``download_ranges`` calls concurrently through ``asyncio.gather``.

    Both calls share one ``asyncio.Lock`` and one mocked read stream. The test
    checks that each call's buffer receives the payload tagged with its own
    ``read_id`` and that exactly one request per call was sent on the stream.
    """
    # Arrange
    data = b"these_are_18_chars"
    crc32c = Checksum(data).digest()
    crc32c_int = int.from_bytes(crc32c, "big")
    crc32c_checksum_for_data_slice = int.from_bytes(
        Checksum(data[10:16]).digest(), "big"
    )

    mock_mrd = await self._make_mock_mrd(
        mock_grpc_client, mock_cls_async_read_object_stream
    )
    # Every download_ranges call draws one id for the call itself and then
    # one read_id per requested range, so two single-range calls consume
    # four values: (123, 456) and (789, 91011).
    mock_random_int.side_effect = [123, 456, 789, 91011]
    mock_mrd.read_obj_str.send = AsyncMock()
    mock_mrd.read_obj_str.recv = AsyncMock()

    # One response per requested range, each tagged with the read_id that the
    # corresponding request is expected to carry.
    mock_mrd.read_obj_str.recv.side_effect = [
        _storage_v2.BidiReadObjectResponse(
            object_data_ranges=[
                _storage_v2.ObjectRangeData(
                    checksummed_data=_storage_v2.ChecksummedData(
                        content=data, crc32c=crc32c_int
                    ),
                    range_end=True,
                    read_range=_storage_v2.ReadRange(
                        read_offset=0, read_length=18, read_id=456
                    ),
                )
            ]
        ),
        _storage_v2.BidiReadObjectResponse(
            object_data_ranges=[
                _storage_v2.ObjectRangeData(
                    checksummed_data=_storage_v2.ChecksummedData(
                        content=data[10:16],
                        crc32c=crc32c_checksum_for_data_slice,
                    ),
                    range_end=True,
                    read_range=_storage_v2.ReadRange(
                        read_offset=10, read_length=6, read_id=91011
                    ),
                )
            ],
        ),
    ]

    # Act
    buffer = BytesIO()
    second_buffer = BytesIO()
    lock = asyncio.Lock()
    task1 = asyncio.create_task(mock_mrd.download_ranges([(0, 18, buffer)], lock))
    task2 = asyncio.create_task(
        mock_mrd.download_ranges([(10, 6, second_buffer)], lock)
    )
    await asyncio.gather(task1, task2)

    # Assert
    # Verify the requests that were actually sent. Assigning ``side_effect``
    # here (as was done previously) only reconfigures the mock after the fact
    # and checks nothing.
    mock_mrd.read_obj_str.send.assert_has_calls(
        [
            mock.call(
                _storage_v2.BidiReadObjectRequest(
                    read_ranges=[
                        _storage_v2.ReadRange(
                            read_offset=0, read_length=18, read_id=456
                        )
                    ]
                )
            ),
            mock.call(
                _storage_v2.BidiReadObjectRequest(
                    read_ranges=[
                        _storage_v2.ReadRange(
                            read_offset=10, read_length=6, read_id=91011
                        )
                    ]
                )
            ),
        ]
    )
    assert mock_mrd.read_obj_str.send.call_count == 2
    assert buffer.getvalue() == data
    assert second_buffer.getvalue() == data[10:16]

@mock.patch(
"google.cloud.storage._experimental.asyncio.async_multi_range_downloader.generate_random_56_bit_integer"
)
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream"
)
Expand All @@ -115,7 +203,7 @@ async def test_create_mrd(
)
@pytest.mark.asyncio
async def test_download_ranges(
self, mock_grpc_client, mock_cls_async_read_object_stream
self, mock_grpc_client, mock_cls_async_read_object_stream, mock_random_int
):
# Arrange
data = b"these_are_18_chars"
Expand All @@ -125,6 +213,7 @@ async def test_download_ranges(
mock_mrd = await self._make_mock_mrd(
mock_grpc_client, mock_cls_async_read_object_stream
)
mock_random_int.side_effect = [123, 456] # for _func_id and read_id
mock_mrd.read_obj_str.send = AsyncMock()
mock_mrd.read_obj_str.recv = AsyncMock()
mock_mrd.read_obj_str.recv.return_value = _storage_v2.BidiReadObjectResponse(
Expand All @@ -135,27 +224,24 @@ async def test_download_ranges(
),
range_end=True,
read_range=_storage_v2.ReadRange(
read_offset=0, read_length=18, read_id=0
read_offset=0, read_length=18, read_id=456
),
)
],
)

# Act
buffer = BytesIO()
results = await mock_mrd.download_ranges([(0, 18, buffer)])
await mock_mrd.download_ranges([(0, 18, buffer)])

# Assert
mock_mrd.read_obj_str.send.assert_called_once_with(
_storage_v2.BidiReadObjectRequest(
read_ranges=[
_storage_v2.ReadRange(read_offset=0, read_length=18, read_id=0)
_storage_v2.ReadRange(read_offset=0, read_length=18, read_id=456)
]
)
)
assert len(results) == 1
assert results[0].bytes_requested == 18
assert results[0].bytes_written == 18
assert buffer.getvalue() == data

@mock.patch(
Expand Down