diff --git a/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py b/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py index 56aba3e27..16dce5025 100644 --- a/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py +++ b/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import annotations +import asyncio import google_crc32c from google.api_core import exceptions from google_crc32c import Checksum @@ -29,6 +30,7 @@ from io import BytesIO from google.cloud import _storage_v2 from google.cloud.storage.exceptions import DataCorruption +from google.cloud.storage._helpers import generate_random_56_bit_integer _MAX_READ_RANGES_PER_BIDI_READ_REQUEST = 100 @@ -78,7 +80,7 @@ class AsyncMultiRangeDownloader: my_buff2 = BytesIO() my_buff3 = BytesIO() my_buff4 = any_object_which_provides_BytesIO_like_interface() - results_arr = await mrd.download_ranges( + await mrd.download_ranges( [ # (start_byte, bytes_to_read, writeable_buffer) (0, 100, my_buff1), @@ -88,8 +90,8 @@ class AsyncMultiRangeDownloader: ] ) - for result in results_arr: - print("downloaded bytes", result) + # verify data in buffers... + assert my_buff2.getbuffer().nbytes == 20 """ @@ -175,6 +177,10 @@ def __init__( self.read_obj_str: Optional[_AsyncReadObjectStream] = None self._is_stream_open: bool = False + self._read_id_to_writable_buffer_dict = {} + self._read_id_to_download_ranges_id = {} + self._download_ranges_id_to_pending_read_ids = {} + async def open(self) -> None: """Opens the bidi-gRPC connection to read from the object. 
@@ -203,8 +209,8 @@ async def open(self) -> None: return async def download_ranges( - self, read_ranges: List[Tuple[int, int, BytesIO]] - ) -> List[Result]: + self, read_ranges: List[Tuple[int, int, BytesIO]], lock: asyncio.Lock = None + ) -> None: """Downloads multiple byte ranges from the object into the buffers provided by user. @@ -214,9 +220,36 @@ async def download_ranges( to be provided by the user, and user has to make sure appropriate memory is available in the application to avoid out-of-memory crash. - :rtype: List[:class:`~google.cloud.storage._experimental.asyncio.async_multi_range_downloader.Result`] - :returns: A list of ``Result`` objects, where each object corresponds - to a requested range. + :type lock: asyncio.Lock + :param lock: (Optional) An asyncio lock to synchronize sends and recvs + on the underlying bidi-gRPC stream. This is required when multiple + coroutines are calling this method concurrently. + + For example, usage with multiple coroutines: + + ``` + lock = asyncio.Lock() + task1 = asyncio.create_task(mrd.download_ranges(ranges1, lock)) + task2 = asyncio.create_task(mrd.download_ranges(ranges2, lock)) + await asyncio.gather(task1, task2) + + ``` + + If the user wants to call this method serially from multiple coroutines, + then providing a lock is not necessary. + + ``` + await mrd.download_ranges(ranges1) + await mrd.download_ranges(ranges2) + + # ... some other code ... + + ``` + + + :raises ValueError: if the underlying bidi-gRPC stream is not open. + :raises ValueError: if the length of read_ranges is more than 1000. + :raises DataCorruption: if a checksum mismatch is detected while reading data. 
""" @@ -228,8 +261,11 @@ async def download_ranges( if not self._is_stream_open: raise ValueError("Underlying bidi-gRPC stream is not open") - read_id_to_writable_buffer_dict = {} - results = [] + if lock is None: + lock = asyncio.Lock() + + _func_id = generate_random_56_bit_integer() + read_ids_in_current_func = set() for i in range(0, len(read_ranges), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST): read_ranges_segment = read_ranges[ i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST @@ -237,10 +273,11 @@ async def download_ranges( read_ranges_for_bidi_req = [] for j, read_range in enumerate(read_ranges_segment): - read_id = i + j - read_id_to_writable_buffer_dict[read_id] = read_range[2] + read_id = generate_random_56_bit_integer() + read_ids_in_current_func.add(read_id) + self._read_id_to_download_ranges_id[read_id] = _func_id + self._read_id_to_writable_buffer_dict[read_id] = read_range[2] bytes_requested = read_range[1] - results.append(Result(bytes_requested)) read_ranges_for_bidi_req.append( _storage_v2.ReadRange( read_offset=read_range[0], @@ -248,12 +285,19 @@ async def download_ranges( read_id=read_id, ) ) - await self.read_obj_str.send( - _storage_v2.BidiReadObjectRequest(read_ranges=read_ranges_for_bidi_req) - ) + async with lock: + await self.read_obj_str.send( + _storage_v2.BidiReadObjectRequest( + read_ranges=read_ranges_for_bidi_req + ) + ) + self._download_ranges_id_to_pending_read_ids[ + _func_id + ] = read_ids_in_current_func - while len(read_id_to_writable_buffer_dict) > 0: - response = await self.read_obj_str.recv() + while len(self._download_ranges_id_to_pending_read_ids[_func_id]) > 0: + async with lock: + response = await self.read_obj_str.recv() if response is None: raise Exception("None response received, something went wrong.") @@ -277,16 +321,15 @@ async def download_ranges( ) read_id = object_data_range.read_range.read_id - buffer = read_id_to_writable_buffer_dict[read_id] + buffer = self._read_id_to_writable_buffer_dict[read_id] 
def generate_random_56_bit_integer() -> int:
    """Generate a cryptographically secure random 56-bit integer.

    56 bits are used instead of 64 because a random 64-bit value can be
    greater than the maximum positive value of a signed 64-bit integer
    (2^63 - 1), which causes overflow issues.

    :rtype: int
    :returns: A secure random integer in the range ``[0, 2**56)``.
    """
    # ``secrets.randbits`` yields a non-negative int with the requested number
    # of random bits, so no manual bytes-to-int conversion is needed.
    return secrets.randbits(56)
@mock.patch(
    "google.cloud.storage._experimental.asyncio.async_multi_range_downloader.generate_random_56_bit_integer"
)
@mock.patch(
    "google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream"
)
@mock.patch(
    "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
)
@pytest.mark.asyncio
async def test_download_ranges_via_async_gather(
    self, mock_grpc_client, mock_cls_async_read_object_stream, mock_random_int
):
    """Two ``download_ranges`` tasks sharing one lock each fill their own buffer.

    Both calls are scheduled with ``asyncio.create_task`` and awaited through
    ``asyncio.gather``. The mocked stream replies with one response per read
    id, and the test checks both the requests sent on the stream and the
    bytes written to each buffer.
    """
    # Arrange
    data = b"these_are_18_chars"
    crc32c = Checksum(data).digest()
    crc32c_int = int.from_bytes(crc32c, "big")
    crc32c_checksum_for_data_slice = int.from_bytes(
        Checksum(data[10:16]).digest(), "big"
    )

    mock_mrd = await self._make_mock_mrd(
        mock_grpc_client, mock_cls_async_read_object_stream
    )
    # Each download_ranges call draws one id for the call itself and one per
    # requested range: (123, 456) for the first task, (789, 91011) for the
    # second -- assuming the first task consumes the first two values.
    mock_random_int.side_effect = [123, 456, 789, 91011]
    mock_mrd.read_obj_str.send = AsyncMock()
    mock_mrd.read_obj_str.recv = AsyncMock()

    mock_mrd.read_obj_str.recv.side_effect = [
        _storage_v2.BidiReadObjectResponse(
            object_data_ranges=[
                _storage_v2.ObjectRangeData(
                    checksummed_data=_storage_v2.ChecksummedData(
                        content=data, crc32c=crc32c_int
                    ),
                    range_end=True,
                    read_range=_storage_v2.ReadRange(
                        read_offset=0, read_length=18, read_id=456
                    ),
                )
            ]
        ),
        _storage_v2.BidiReadObjectResponse(
            object_data_ranges=[
                _storage_v2.ObjectRangeData(
                    checksummed_data=_storage_v2.ChecksummedData(
                        content=data[10:16],
                        crc32c=crc32c_checksum_for_data_slice,
                    ),
                    range_end=True,
                    read_range=_storage_v2.ReadRange(
                        read_offset=10, read_length=6, read_id=91011
                    ),
                )
            ],
        ),
    ]

    # Act
    buffer = BytesIO()
    second_buffer = BytesIO()
    lock = asyncio.Lock()
    task1 = asyncio.create_task(mock_mrd.download_ranges([(0, 18, buffer)], lock))
    task2 = asyncio.create_task(
        mock_mrd.download_ranges([(10, 6, second_buffer)], lock)
    )
    await asyncio.gather(task1, task2)

    # Assert
    # Verify the requests that were actually sent. Assigning ``side_effect``
    # after the calls have happened would not check anything.
    send_mock = mock_mrd.read_obj_str.send
    assert send_mock.call_count == 2
    send_mock.assert_has_calls(
        [
            mock.call(
                _storage_v2.BidiReadObjectRequest(
                    read_ranges=[
                        _storage_v2.ReadRange(
                            read_offset=0, read_length=18, read_id=456
                        )
                    ]
                )
            ),
            mock.call(
                _storage_v2.BidiReadObjectRequest(
                    read_ranges=[
                        _storage_v2.ReadRange(
                            read_offset=10, read_length=6, read_id=91011
                        )
                    ]
                )
            ),
        ]
    )
    assert buffer.getvalue() == data
    assert second_buffer.getvalue() == data[10:16]
read_ranges=[ - _storage_v2.ReadRange(read_offset=0, read_length=18, read_id=0) + _storage_v2.ReadRange(read_offset=0, read_length=18, read_id=456) ] ) ) - assert len(results) == 1 - assert results[0].bytes_requested == 18 - assert results[0].bytes_written == 18 assert buffer.getvalue() == data @mock.patch(