Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3083.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added support for async vectorized and orthogonal indexing.
67 changes: 66 additions & 1 deletion src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
ZarrFormat,
_default_zarr_format,
_warn_order_kwarg,
ceildiv,
concurrent_map,
parse_shapelike,
product,
Expand All @@ -76,6 +77,8 @@
)
from zarr.core.dtype.common import HasEndianness, HasItemSize, HasObjectCodec
from zarr.core.indexing import (
AsyncOIndex,
AsyncVIndex,
BasicIndexer,
BasicSelection,
BlockIndex,
Expand All @@ -92,7 +95,6 @@
Selection,
VIndex,
_iter_grid,
ceildiv,
check_fields,
check_no_multi_fields,
is_pure_fancy_indexing,
Expand Down Expand Up @@ -1425,6 +1427,56 @@ async def getitem(
)
return await self._get_selection(indexer, prototype=prototype)

async def get_orthogonal_selection(
self,
selection: OrthogonalSelection,
*,
out: NDBuffer | None = None,
fields: Fields | None = None,
prototype: BufferPrototype | None = None,
) -> NDArrayLikeOrScalar:
if prototype is None:
prototype = default_buffer_prototype()
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
return await self._get_selection(
indexer=indexer, out=out, fields=fields, prototype=prototype
)

async def get_mask_selection(
self,
mask: MaskSelection,
*,
out: NDBuffer | None = None,
fields: Fields | None = None,
prototype: BufferPrototype | None = None,
) -> NDArrayLikeOrScalar:
if prototype is None:
prototype = default_buffer_prototype()
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
return await self._get_selection(
indexer=indexer, out=out, fields=fields, prototype=prototype
)

async def get_coordinate_selection(
self,
selection: CoordinateSelection,
*,
out: NDBuffer | None = None,
fields: Fields | None = None,
prototype: BufferPrototype | None = None,
) -> NDArrayLikeOrScalar:
if prototype is None:
prototype = default_buffer_prototype()
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
out_array = await self._get_selection(
indexer=indexer, out=out, fields=fields, prototype=prototype
)

if hasattr(out_array, "shape"):
# restore shape
out_array = np.array(out_array).reshape(indexer.sel_shape)
return out_array

async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None:
"""
Asynchronously save the array metadata.
Expand Down Expand Up @@ -1556,6 +1608,19 @@ async def setitem(
)
return await self._set_selection(indexer, value, prototype=prototype)

@property
def oindex(self) -> AsyncOIndex[T_ArrayMetadata]:
"""Shortcut for orthogonal (outer) indexing, see :func:`get_orthogonal_selection` and
:func:`set_orthogonal_selection` for documentation and examples."""
return AsyncOIndex(self)

@property
def vindex(self) -> AsyncVIndex[T_ArrayMetadata]:
"""Shortcut for vectorized (inner) indexing, see :func:`get_coordinate_selection`,
:func:`set_coordinate_selection`, :func:`get_mask_selection` and
:func:`set_mask_selection` for documentation and examples."""
return AsyncVIndex(self)

async def resize(self, new_shape: ShapeLike, delete_outside_chunks: bool = True) -> None:
"""
Asynchronously resize the array to a new shape.
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/core/chunk_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
ChunkCoords,
ChunkCoordsLike,
ShapeLike,
ceildiv,
parse_named_configuration,
parse_shapelike,
)
from zarr.core.indexing import ceildiv

if TYPE_CHECKING:
from collections.abc import Iterator
Expand Down
7 changes: 7 additions & 0 deletions src/zarr/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import functools
import math
import operator
import warnings
from collections.abc import Iterable, Mapping, Sequence
Expand Down Expand Up @@ -69,6 +70,12 @@ def product(tup: ChunkCoords) -> int:
return functools.reduce(operator.mul, tup, 1)


def ceildiv(a: float, b: float) -> int:
if a == 0:
return 0
return math.ceil(a / b)


T = TypeVar("T", bound=tuple[Any, ...])
V = TypeVar("V")

Expand Down
58 changes: 50 additions & 8 deletions src/zarr/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import (
TYPE_CHECKING,
Any,
Generic,
Literal,
NamedTuple,
Protocol,
Expand All @@ -25,14 +26,16 @@
import numpy as np
import numpy.typing as npt

from zarr.core.common import product
from zarr.core.common import ceildiv, product
from zarr.core.metadata import T_ArrayMetadata

if TYPE_CHECKING:
from zarr.core.array import Array
from zarr.core.array import Array, AsyncArray
from zarr.core.buffer import NDArrayLikeOrScalar
from zarr.core.chunk_grids import ChunkGrid
from zarr.core.common import ChunkCoords


IntSequence = list[int] | npt.NDArray[np.intp]
ArrayOfIntOrBool = npt.NDArray[np.intp] | npt.NDArray[np.bool_]
BasicSelector = int | slice | EllipsisType
Expand Down Expand Up @@ -93,12 +96,6 @@ class Indexer(Protocol):
def __iter__(self) -> Iterator[ChunkProjection]: ...


def ceildiv(a: float, b: float) -> int:
if a == 0:
return 0
return math.ceil(a / b)


_ArrayIndexingOrder: TypeAlias = Literal["lexicographic"]


Expand Down Expand Up @@ -960,6 +957,25 @@ def __setitem__(self, selection: OrthogonalSelection, value: npt.ArrayLike) -> N
)


@dataclass(frozen=True)
class AsyncOIndex(Generic[T_ArrayMetadata]):
array: AsyncArray[T_ArrayMetadata]

async def getitem(self, selection: OrthogonalSelection | Array) -> NDArrayLikeOrScalar:
from zarr.core.array import Array

# if input is a Zarr array, we materialize it now.
if isinstance(selection, Array):
selection = _zarr_array_to_int_or_bool_array(selection)

fields, new_selection = pop_fields(selection)
new_selection = ensure_tuple(new_selection)
new_selection = replace_lists(new_selection)
return await self.array.get_orthogonal_selection(
cast(OrthogonalSelection, new_selection), fields=fields
)


@dataclass(frozen=True)
class BlockIndexer(Indexer):
dim_indexers: list[SliceDimIndexer]
Expand Down Expand Up @@ -1268,6 +1284,32 @@ def __setitem__(
raise VindexInvalidSelectionError(new_selection)


@dataclass(frozen=True)
class AsyncVIndex(Generic[T_ArrayMetadata]):
array: AsyncArray[T_ArrayMetadata]

# TODO: develop Array generic and move zarr.Array[np.intp] | zarr.Array[np.bool_] to ArrayOfIntOrBool
async def getitem(
self, selection: CoordinateSelection | MaskSelection | Array
) -> NDArrayLikeOrScalar:
# TODO deduplicate these internals with the sync version of getitem
# TODO requires solving this circular sync issue: https://github.com/zarr-developers/zarr-python/pull/3083#discussion_r2230737448
from zarr.core.array import Array

# if input is a Zarr array, we materialize it now.
if isinstance(selection, Array):
selection = _zarr_array_to_int_or_bool_array(selection)
fields, new_selection = pop_fields(selection)
new_selection = ensure_tuple(new_selection)
new_selection = replace_lists(new_selection)
if is_coordinate_selection(new_selection, self.array.shape):
return await self.array.get_coordinate_selection(new_selection, fields=fields)
elif is_mask_selection(new_selection, self.array.shape):
return await self.array.get_mask_selection(new_selection, fields=fields)
else:
raise VindexInvalidSelectionError(new_selection)


def check_fields(fields: Fields | None, dtype: np.dtype[Any]) -> np.dtype[Any]:
# early out
if fields is None:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from zarr.core.buffer import NDArrayLike, NDArrayLikeOrScalar, default_buffer_prototype
from zarr.core.chunk_grids import _auto_partition
from zarr.core.chunk_key_encodings import ChunkKeyEncodingParams
from zarr.core.common import JSON, ZarrFormat
from zarr.core.common import JSON, ZarrFormat, ceildiv
from zarr.core.dtype import (
DateTime64,
Float32,
Expand All @@ -59,7 +59,7 @@
from zarr.core.dtype.npy.common import NUMPY_ENDIANNESS_STR, endianness_from_numpy_str
from zarr.core.dtype.npy.string import UTF8Base
from zarr.core.group import AsyncGroup
from zarr.core.indexing import BasicIndexer, ceildiv
from zarr.core.indexing import BasicIndexer
from zarr.core.metadata.v2 import ArrayV2Metadata
from zarr.core.metadata.v3 import ArrayV3Metadata
from zarr.core.sync import sync
Expand Down
107 changes: 107 additions & 0 deletions tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1994,3 +1994,110 @@ def test_iter_chunk_regions():
assert_array_equal(a[region], np.ones_like(a[region]))
a[region] = 0
assert_array_equal(a[region], np.zeros_like(a[region]))


class TestAsync:
@pytest.mark.parametrize(
("indexer", "expected"),
[
# int
((0,), np.array([1, 2])),
((1,), np.array([3, 4])),
((0, 1), np.array(2)),
# slice
((slice(None),), np.array([[1, 2], [3, 4]])),
((slice(0, 1),), np.array([[1, 2]])),
((slice(1, 2),), np.array([[3, 4]])),
((slice(0, 2),), np.array([[1, 2], [3, 4]])),
((slice(0, 0),), np.empty(shape=(0, 2), dtype="i8")),
# ellipsis
((...,), np.array([[1, 2], [3, 4]])),
((0, ...), np.array([1, 2])),
((..., 0), np.array([1, 3])),
((0, 1, ...), np.array(2)),
# combined
((0, slice(None)), np.array([1, 2])),
((slice(None), 0), np.array([1, 3])),
((slice(None), slice(None)), np.array([[1, 2], [3, 4]])),
# array of ints
(([0]), np.array([[1, 2]])),
(([1]), np.array([[3, 4]])),
(([0], [1]), np.array(2)),
(([0, 1], [0]), np.array([[1], [3]])),
(([0, 1], [0, 1]), np.array([[1, 2], [3, 4]])),
# boolean array
(np.array([True, True]), np.array([[1, 2], [3, 4]])),
(np.array([True, False]), np.array([[1, 2]])),
(np.array([False, True]), np.array([[3, 4]])),
(np.array([False, False]), np.empty(shape=(0, 2), dtype="i8")),
],
)
@pytest.mark.asyncio
async def test_async_oindex(self, store, indexer, expected):
z = zarr.create_array(store=store, shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="i8")
z[...] = np.array([[1, 2], [3, 4]])
async_zarr = z._async_array

result = await async_zarr.oindex.getitem(indexer)
assert_array_equal(result, expected)

@pytest.mark.asyncio
async def test_async_oindex_with_zarr_array(self, store):
z1 = zarr.create_array(store=store, shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="i8")
z1[...] = np.array([[1, 2], [3, 4]])
async_zarr = z1._async_array

# create boolean zarr array to index with
z2 = zarr.create_array(
store=store, name="z2", shape=(2,), chunks=(1,), zarr_format=3, dtype="?"
)
z2[...] = np.array([True, False])

result = await async_zarr.oindex.getitem(z2)
expected = np.array([[1, 2]])
assert_array_equal(result, expected)

@pytest.mark.parametrize(
("indexer", "expected"),
[
(([0], [0]), np.array(1)),
(([0, 1], [0, 1]), np.array([1, 4])),
(np.array([[False, True], [False, True]]), np.array([2, 4])),
],
)
@pytest.mark.asyncio
async def test_async_vindex(self, store, indexer, expected):
z = zarr.create_array(store=store, shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="i8")
z[...] = np.array([[1, 2], [3, 4]])
async_zarr = z._async_array

result = await async_zarr.vindex.getitem(indexer)
assert_array_equal(result, expected)

@pytest.mark.asyncio
async def test_async_vindex_with_zarr_array(self, store):
z1 = zarr.create_array(store=store, shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="i8")
z1[...] = np.array([[1, 2], [3, 4]])
async_zarr = z1._async_array

# create boolean zarr array to index with
z2 = zarr.create_array(
store=store, name="z2", shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="?"
)
z2[...] = np.array([[False, True], [False, True]])

result = await async_zarr.vindex.getitem(z2)
expected = np.array([2, 4])
assert_array_equal(result, expected)

@pytest.mark.asyncio
async def test_async_invalid_indexer(self, store):
z = zarr.create_array(store=store, shape=(2, 2), chunks=(1, 1), zarr_format=3, dtype="i8")
z[...] = np.array([[1, 2], [3, 4]])
async_zarr = z._async_array

with pytest.raises(IndexError):
await async_zarr.vindex.getitem("invalid_indexer")

with pytest.raises(IndexError):
await async_zarr.oindex.getitem("invalid_indexer")
Loading